Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
786
vendor/ruvector/crates/ruvector-solver/src/validation.rs
vendored
Normal file
786
vendor/ruvector/crates/ruvector-solver/src/validation.rs
vendored
Normal file
@@ -0,0 +1,786 @@
|
||||
//! Comprehensive input validation for solver operations.
|
||||
//!
|
||||
//! All validation functions run eagerly before any computation begins, ensuring
|
||||
//! callers receive clear diagnostics instead of mysterious numerical failures or
|
||||
//! resource exhaustion. Every public function returns [`ValidationError`] on
|
||||
//! failure, which converts into [`SolverError::InvalidInput`] via `From`.
|
||||
//!
|
||||
//! # Limits
|
||||
//!
|
||||
//! Hard limits are enforced to prevent denial-of-service through oversized
|
||||
//! inputs:
|
||||
//!
|
||||
//! | Resource | Limit | Constant |
|
||||
//! |---------------|------------------------|-------------------|
|
||||
//! | Nodes (rows) | 10,000,000 | [`MAX_NODES`] |
|
||||
//! | Edges (nnz) | 100,000,000 | [`MAX_EDGES`] |
|
||||
//! | Dimension | 65,536 | [`MAX_DIM`] |
|
||||
//! | Iterations | 1,000,000 | [`MAX_ITERATIONS`]|
|
||||
//! | Request body | 10 MiB | [`MAX_BODY_SIZE`] |
|
||||
|
||||
use crate::error::ValidationError;
|
||||
use crate::types::{CsrMatrix, SolverResult};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Resource limits
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Maximum number of rows or columns to prevent resource exhaustion.
|
||||
pub const MAX_NODES: usize = 10_000_000;
|
||||
|
||||
/// Maximum number of non-zero entries.
|
||||
pub const MAX_EDGES: usize = 100_000_000;
|
||||
|
||||
/// Maximum vector/matrix dimension for dense operations.
|
||||
pub const MAX_DIM: usize = 65_536;
|
||||
|
||||
/// Maximum solver iterations to prevent runaway computation.
|
||||
pub const MAX_ITERATIONS: usize = 1_000_000;
|
||||
|
||||
/// Maximum request body size in bytes (10 MiB).
|
||||
pub const MAX_BODY_SIZE: usize = 10 * 1024 * 1024;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// CSR matrix validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate the structural integrity of a CSR matrix.
|
||||
///
|
||||
/// Performs the following checks in order:
|
||||
///
|
||||
/// 1. `rows` and `cols` are within [`MAX_NODES`].
|
||||
/// 2. `nnz` (number of non-zeros) is within [`MAX_EDGES`].
|
||||
/// 3. `row_ptr` length equals `rows + 1`.
|
||||
/// 4. `row_ptr` is monotonically non-decreasing.
|
||||
/// 5. `row_ptr[0] == 0` and `row_ptr[rows] == nnz`.
|
||||
/// 6. `col_indices` length equals `values` length.
|
||||
/// 7. All column indices are less than `cols`.
|
||||
/// 8. No `NaN` or `Inf` values in `values`.
|
||||
/// 9. Column indices are sorted within each row (emits a [`tracing::warn`] if
|
||||
/// not, but does not error).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] describing the first violation found.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ruvector_solver::types::CsrMatrix;
|
||||
/// use ruvector_solver::validation::validate_csr_matrix;
|
||||
///
|
||||
/// let m = CsrMatrix::<f32>::from_coo(2, 2, vec![(0, 0, 1.0), (1, 1, 2.0)]);
|
||||
/// assert!(validate_csr_matrix(&m).is_ok());
|
||||
/// ```
|
||||
pub fn validate_csr_matrix(matrix: &CsrMatrix<f32>) -> Result<(), ValidationError> {
|
||||
// 1. Dimension bounds
|
||||
if matrix.rows > MAX_NODES || matrix.cols > MAX_NODES {
|
||||
return Err(ValidationError::MatrixTooLarge {
|
||||
rows: matrix.rows,
|
||||
cols: matrix.cols,
|
||||
max_dim: MAX_NODES,
|
||||
});
|
||||
}
|
||||
|
||||
// 2. NNZ bounds
|
||||
let nnz = matrix.values.len();
|
||||
if nnz > MAX_EDGES {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"nnz {} exceeds maximum allowed {}",
|
||||
nnz, MAX_EDGES,
|
||||
)));
|
||||
}
|
||||
|
||||
// 3. row_ptr length
|
||||
let expected_row_ptr_len = matrix.rows + 1;
|
||||
if matrix.row_ptr.len() != expected_row_ptr_len {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"row_ptr length {} does not equal rows + 1 = {}",
|
||||
matrix.row_ptr.len(),
|
||||
expected_row_ptr_len,
|
||||
)));
|
||||
}
|
||||
|
||||
// 4. row_ptr monotonicity
|
||||
for i in 1..matrix.row_ptr.len() {
|
||||
if matrix.row_ptr[i] < matrix.row_ptr[i - 1] {
|
||||
return Err(ValidationError::NonMonotonicRowPtrs { position: i });
|
||||
}
|
||||
}
|
||||
|
||||
// 5. row_ptr boundary values
|
||||
if matrix.row_ptr[0] != 0 {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"row_ptr[0] = {} (expected 0)",
|
||||
matrix.row_ptr[0],
|
||||
)));
|
||||
}
|
||||
let expected_nnz = matrix.row_ptr[matrix.rows];
|
||||
if expected_nnz != nnz {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"values length {} does not match row_ptr[rows] = {}",
|
||||
nnz, expected_nnz,
|
||||
)));
|
||||
}
|
||||
|
||||
// 6. col_indices length must match values length
|
||||
if matrix.col_indices.len() != nnz {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"col_indices length {} does not match values length {}",
|
||||
matrix.col_indices.len(),
|
||||
nnz,
|
||||
)));
|
||||
}
|
||||
|
||||
// 7. Column index bounds + 9. Sorted check (warn only) + 8. Finiteness
|
||||
for row in 0..matrix.rows {
|
||||
let start = matrix.row_ptr[row];
|
||||
let end = matrix.row_ptr[row + 1];
|
||||
|
||||
let mut prev_col: Option<usize> = None;
|
||||
for idx in start..end {
|
||||
let col = matrix.col_indices[idx];
|
||||
if col >= matrix.cols {
|
||||
return Err(ValidationError::IndexOutOfBounds {
|
||||
index: col as u32,
|
||||
row,
|
||||
cols: matrix.cols,
|
||||
});
|
||||
}
|
||||
|
||||
let val = matrix.values[idx];
|
||||
if !val.is_finite() {
|
||||
return Err(ValidationError::NonFiniteValue(format!(
|
||||
"matrix[{}, {}] = {}",
|
||||
row, col, val,
|
||||
)));
|
||||
}
|
||||
|
||||
// Check sorted order within row (warn, not error)
|
||||
if let Some(pc) = prev_col {
|
||||
if col < pc {
|
||||
tracing::warn!(
|
||||
row = row,
|
||||
"column indices not sorted within row (col {} follows {}); \
|
||||
performance may be degraded",
|
||||
col,
|
||||
pc,
|
||||
);
|
||||
}
|
||||
}
|
||||
prev_col = Some(col);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// RHS vector validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate a right-hand-side vector for a linear solve.
|
||||
///
|
||||
/// Checks:
|
||||
///
|
||||
/// 1. `rhs.len() == expected_len` (dimension must match the matrix).
|
||||
/// 2. No `NaN` or `Inf` entries.
|
||||
/// 3. If all entries are zero, emits a [`tracing::warn`] (a zero RHS is
|
||||
/// technically valid but often indicates a bug).
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] on dimension mismatch or non-finite values.
|
||||
pub fn validate_rhs(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
|
||||
// 1. Length check
|
||||
if rhs.len() != expected_len {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"rhs length {} does not match expected {}",
|
||||
rhs.len(),
|
||||
expected_len,
|
||||
)));
|
||||
}
|
||||
|
||||
// 2. Finite check + 3. All-zeros check
|
||||
let mut all_zero = true;
|
||||
for (i, &v) in rhs.iter().enumerate() {
|
||||
if !v.is_finite() {
|
||||
return Err(ValidationError::NonFiniteValue(format!(
|
||||
"rhs[{}] = {}",
|
||||
i, v,
|
||||
)));
|
||||
}
|
||||
if v != 0.0 {
|
||||
all_zero = false;
|
||||
}
|
||||
}
|
||||
|
||||
if all_zero && !rhs.is_empty() {
|
||||
tracing::warn!("rhs vector is all zeros; solution will be trivially zero");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate the right-hand side vector `b` for compatibility with a matrix.
|
||||
///
|
||||
/// This is an alias for [`validate_rhs`] that preserves backward compatibility
|
||||
/// with the original API name.
|
||||
pub fn validate_rhs_vector(rhs: &[f32], expected_len: usize) -> Result<(), ValidationError> {
|
||||
validate_rhs(rhs, expected_len)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Solver parameter validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate solver convergence parameters.
|
||||
///
|
||||
/// # Rules
|
||||
///
|
||||
/// - `tolerance` must be in the range `(0.0, 1.0]` and be finite.
|
||||
/// - `max_iterations` must be in `[1, MAX_ITERATIONS]`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError::ParameterOutOfRange`] if either parameter is
|
||||
/// outside its valid range.
|
||||
pub fn validate_params(tolerance: f64, max_iterations: usize) -> Result<(), ValidationError> {
|
||||
if !tolerance.is_finite() || tolerance <= 0.0 || tolerance > 1.0 {
|
||||
return Err(ValidationError::ParameterOutOfRange {
|
||||
name: "tolerance".into(),
|
||||
value: format!("{tolerance:.2e}"),
|
||||
expected: "(0.0, 1.0]".into(),
|
||||
});
|
||||
}
|
||||
|
||||
if max_iterations == 0 || max_iterations > MAX_ITERATIONS {
|
||||
return Err(ValidationError::ParameterOutOfRange {
|
||||
name: "max_iterations".into(),
|
||||
value: max_iterations.to_string(),
|
||||
expected: format!("[1, {}]", MAX_ITERATIONS),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Combined solver input validation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate the complete solver input (matrix + rhs + parameters).
|
||||
///
|
||||
/// This is a convenience function that calls [`validate_csr_matrix`],
|
||||
/// [`validate_rhs`], and validates tolerance in sequence. It also checks
|
||||
/// that the matrix is square, which is required by all iterative solvers.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] on the first failing check.
|
||||
pub fn validate_solver_input(
|
||||
matrix: &CsrMatrix<f32>,
|
||||
rhs: &[f32],
|
||||
tolerance: f64,
|
||||
) -> Result<(), ValidationError> {
|
||||
validate_csr_matrix(matrix)?;
|
||||
validate_rhs(rhs, matrix.rows)?;
|
||||
|
||||
// Square matrix required for iterative solvers.
|
||||
if matrix.rows != matrix.cols {
|
||||
return Err(ValidationError::DimensionMismatch(format!(
|
||||
"solver requires a square matrix but got {}x{}",
|
||||
matrix.rows, matrix.cols,
|
||||
)));
|
||||
}
|
||||
|
||||
// Tolerance bounds.
|
||||
if !tolerance.is_finite() || tolerance <= 0.0 {
|
||||
return Err(ValidationError::ParameterOutOfRange {
|
||||
name: "tolerance".into(),
|
||||
value: tolerance.to_string(),
|
||||
expected: "finite positive value".into(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Output validation (post-solve)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate a solver result after computation completes.
|
||||
///
|
||||
/// This catches silent numerical corruption that may have occurred during
|
||||
/// iteration:
|
||||
///
|
||||
/// 1. No `NaN` or `Inf` in the solution vector.
|
||||
/// 2. The residual norm is finite.
|
||||
/// 3. At least one iteration was performed.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError`] if the output is corrupted.
|
||||
pub fn validate_output(result: &SolverResult) -> Result<(), ValidationError> {
|
||||
// 1. Solution vector finiteness
|
||||
for (i, &v) in result.solution.iter().enumerate() {
|
||||
if !v.is_finite() {
|
||||
return Err(ValidationError::NonFiniteValue(format!(
|
||||
"solution[{}] = {}",
|
||||
i, v,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Residual finiteness
|
||||
if !result.residual_norm.is_finite() {
|
||||
return Err(ValidationError::NonFiniteValue(format!(
|
||||
"residual_norm = {}",
|
||||
result.residual_norm,
|
||||
)));
|
||||
}
|
||||
|
||||
// 3. Iteration count
|
||||
if result.iterations == 0 {
|
||||
return Err(ValidationError::ParameterOutOfRange {
|
||||
name: "iterations".into(),
|
||||
value: "0".into(),
|
||||
expected: ">= 1".into(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Body size validation (for API / deserialization boundaries)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate that a request body does not exceed [`MAX_BODY_SIZE`].
|
||||
///
|
||||
/// Call this at the deserialization boundary before parsing untrusted input.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns [`ValidationError::ParameterOutOfRange`] if `size > MAX_BODY_SIZE`.
|
||||
pub fn validate_body_size(size: usize) -> Result<(), ValidationError> {
|
||||
if size > MAX_BODY_SIZE {
|
||||
return Err(ValidationError::ParameterOutOfRange {
|
||||
name: "body_size".into(),
|
||||
value: format!("{} bytes", size),
|
||||
expected: format!("<= {} bytes (10 MiB)", MAX_BODY_SIZE),
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::{Algorithm, ConvergenceInfo, CsrMatrix, SolverResult};
|
||||
use std::time::Duration;
|
||||
|
||||
fn make_identity(n: usize) -> CsrMatrix<f32> {
|
||||
let mut row_ptr = vec![0usize; n + 1];
|
||||
let mut col_indices = Vec::with_capacity(n);
|
||||
let mut values = Vec::with_capacity(n);
|
||||
for i in 0..n {
|
||||
row_ptr[i + 1] = i + 1;
|
||||
col_indices.push(i);
|
||||
values.push(1.0);
|
||||
}
|
||||
CsrMatrix {
|
||||
values,
|
||||
col_indices,
|
||||
row_ptr,
|
||||
rows: n,
|
||||
cols: n,
|
||||
}
|
||||
}
|
||||
|
||||
// -- validate_csr_matrix ------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn valid_identity() {
|
||||
let mat = make_identity(4);
|
||||
assert!(validate_csr_matrix(&mat).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_empty_matrix() {
|
||||
let m = CsrMatrix {
|
||||
row_ptr: vec![0],
|
||||
col_indices: vec![],
|
||||
values: vec![],
|
||||
rows: 0,
|
||||
cols: 0,
|
||||
};
|
||||
assert!(validate_csr_matrix(&m).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn valid_from_coo() {
|
||||
let m = CsrMatrix::<f32>::from_coo(
|
||||
3,
|
||||
3,
|
||||
vec![
|
||||
(0, 0, 2.0),
|
||||
(0, 1, -0.5),
|
||||
(1, 0, -0.5),
|
||||
(1, 1, 2.0),
|
||||
(1, 2, -0.5),
|
||||
(2, 1, -0.5),
|
||||
(2, 2, 2.0),
|
||||
],
|
||||
);
|
||||
assert!(validate_csr_matrix(&m).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_too_large_matrix() {
|
||||
let m = CsrMatrix {
|
||||
row_ptr: vec![0, 0],
|
||||
col_indices: vec![],
|
||||
values: vec![],
|
||||
rows: MAX_NODES + 1,
|
||||
cols: 1,
|
||||
};
|
||||
assert!(matches!(
|
||||
validate_csr_matrix(&m),
|
||||
Err(ValidationError::MatrixTooLarge { .. })
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_wrong_row_ptr_length() {
|
||||
let m = CsrMatrix {
|
||||
row_ptr: vec![0, 1],
|
||||
col_indices: vec![0],
|
||||
values: vec![1.0],
|
||||
rows: 3,
|
||||
cols: 3,
|
||||
};
|
||||
assert!(matches!(
|
||||
validate_csr_matrix(&m),
|
||||
Err(ValidationError::DimensionMismatch(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_monotonic_row_ptr() {
|
||||
let mut mat = make_identity(4);
|
||||
mat.row_ptr[2] = 0; // break monotonicity
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::NonMonotonicRowPtrs { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_row_ptr_not_starting_at_zero() {
|
||||
let m = CsrMatrix {
|
||||
row_ptr: vec![1, 2],
|
||||
col_indices: vec![0],
|
||||
values: vec![1.0],
|
||||
rows: 1,
|
||||
cols: 1,
|
||||
};
|
||||
match validate_csr_matrix(&m) {
|
||||
Err(ValidationError::DimensionMismatch(msg)) => {
|
||||
assert!(msg.contains("row_ptr[0]"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected DimensionMismatch for row_ptr[0], got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn col_index_out_of_bounds() {
|
||||
let mut mat = make_identity(4);
|
||||
mat.col_indices[1] = 99;
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::IndexOutOfBounds { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn nan_value_rejected() {
|
||||
let mut mat = make_identity(4);
|
||||
mat.values[0] = f32::NAN;
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::NonFiniteValue(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inf_value_rejected() {
|
||||
let mut mat = make_identity(4);
|
||||
mat.values[0] = f32::INFINITY;
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::NonFiniteValue(_)));
|
||||
}
|
||||
|
||||
// -- validate_rhs -------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn valid_rhs() {
|
||||
assert!(validate_rhs(&[1.0, 2.0, 3.0], 3).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rhs_dimension_mismatch() {
|
||||
let err = validate_rhs(&[1.0, 2.0], 3).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::DimensionMismatch(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rhs_nan_rejected() {
|
||||
let err = validate_rhs(&[1.0, f32::NAN, 3.0], 3).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::NonFiniteValue(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rhs_inf_rejected() {
|
||||
let err = validate_rhs(&[1.0, f32::NEG_INFINITY, 3.0], 3).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::NonFiniteValue(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn warns_on_all_zero_rhs() {
|
||||
// Should succeed but emit a warning (cannot assert warning in unit test,
|
||||
// but at least verify it does not error).
|
||||
assert!(validate_rhs(&[0.0, 0.0, 0.0], 3).is_ok());
|
||||
}
|
||||
|
||||
// -- validate_rhs_vector (backward compat alias) ------------------------
|
||||
|
||||
#[test]
|
||||
fn rhs_vector_alias_works() {
|
||||
assert!(validate_rhs_vector(&[1.0, 2.0], 2).is_ok());
|
||||
assert!(validate_rhs_vector(&[1.0, 2.0], 3).is_err());
|
||||
}
|
||||
|
||||
// -- validate_params ----------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn valid_params() {
|
||||
assert!(validate_params(1e-8, 500).is_ok());
|
||||
assert!(validate_params(1.0, 1).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_zero_tolerance() {
|
||||
match validate_params(0.0, 100) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "tolerance");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_negative_tolerance() {
|
||||
match validate_params(-1e-6, 100) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "tolerance");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_tolerance_above_one() {
|
||||
match validate_params(1.5, 100) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "tolerance");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_nan_tolerance() {
|
||||
match validate_params(f64::NAN, 100) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "tolerance");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for tolerance, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_zero_iterations() {
|
||||
match validate_params(1e-6, 0) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "max_iterations");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_excessive_iterations() {
|
||||
match validate_params(1e-6, MAX_ITERATIONS + 1) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "max_iterations");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange for max_iterations, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// -- validate_solver_input (combined) -----------------------------------
|
||||
|
||||
#[test]
|
||||
fn full_input_validation() {
|
||||
let mat = make_identity(3);
|
||||
let rhs = vec![1.0f32, 2.0, 3.0];
|
||||
assert!(validate_solver_input(&mat, &rhs, 1e-6).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_square_rejected() {
|
||||
let mat = CsrMatrix {
|
||||
values: vec![],
|
||||
col_indices: vec![],
|
||||
row_ptr: vec![0, 0, 0],
|
||||
rows: 2,
|
||||
cols: 3,
|
||||
};
|
||||
let rhs = vec![1.0f32, 2.0];
|
||||
let err = validate_solver_input(&mat, &rhs, 1e-6).unwrap_err();
|
||||
assert!(matches!(err, ValidationError::DimensionMismatch(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn invalid_tolerance_rejected() {
|
||||
let mat = make_identity(2);
|
||||
let rhs = vec![1.0f32, 2.0];
|
||||
assert!(validate_solver_input(&mat, &rhs, -1.0).is_err());
|
||||
assert!(validate_solver_input(&mat, &rhs, 0.0).is_err());
|
||||
assert!(validate_solver_input(&mat, &rhs, f64::NAN).is_err());
|
||||
}
|
||||
|
||||
// -- validate_output ----------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn valid_output() {
|
||||
let result = SolverResult {
|
||||
solution: vec![1.0, 2.0, 3.0],
|
||||
iterations: 10,
|
||||
residual_norm: 1e-8,
|
||||
wall_time: Duration::from_millis(5),
|
||||
convergence_history: vec![ConvergenceInfo {
|
||||
iteration: 0,
|
||||
residual_norm: 1.0,
|
||||
}],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
assert!(validate_output(&result).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_nan_in_solution() {
|
||||
let result = SolverResult {
|
||||
solution: vec![1.0, f32::NAN, 3.0],
|
||||
iterations: 1,
|
||||
residual_norm: 1e-8,
|
||||
wall_time: Duration::from_millis(1),
|
||||
convergence_history: vec![],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
match validate_output(&result) {
|
||||
Err(ValidationError::NonFiniteValue(ref msg)) => {
|
||||
assert!(msg.contains("solution"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected NonFiniteValue for solution, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_inf_in_solution() {
|
||||
let result = SolverResult {
|
||||
solution: vec![f32::INFINITY],
|
||||
iterations: 1,
|
||||
residual_norm: 1e-8,
|
||||
wall_time: Duration::from_millis(1),
|
||||
convergence_history: vec![],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
match validate_output(&result) {
|
||||
Err(ValidationError::NonFiniteValue(ref msg)) => {
|
||||
assert!(msg.contains("solution"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected NonFiniteValue for solution, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_nan_residual() {
|
||||
let result = SolverResult {
|
||||
solution: vec![1.0],
|
||||
iterations: 1,
|
||||
residual_norm: f64::NAN,
|
||||
wall_time: Duration::from_millis(1),
|
||||
convergence_history: vec![],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
match validate_output(&result) {
|
||||
Err(ValidationError::NonFiniteValue(ref msg)) => {
|
||||
assert!(msg.contains("residual"), "msg: {msg}");
|
||||
}
|
||||
other => panic!("expected NonFiniteValue for residual, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_inf_residual() {
|
||||
let result = SolverResult {
|
||||
solution: vec![1.0],
|
||||
iterations: 1,
|
||||
residual_norm: f64::INFINITY,
|
||||
wall_time: Duration::from_millis(1),
|
||||
convergence_history: vec![],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
assert!(matches!(
|
||||
validate_output(&result),
|
||||
Err(ValidationError::NonFiniteValue(_))
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_zero_iterations_in_output() {
|
||||
let result = SolverResult {
|
||||
solution: vec![1.0],
|
||||
iterations: 0,
|
||||
residual_norm: 1e-8,
|
||||
wall_time: Duration::from_millis(1),
|
||||
convergence_history: vec![],
|
||||
algorithm: Algorithm::Neumann,
|
||||
};
|
||||
match validate_output(&result) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "iterations");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange, got {other:?}"),
|
||||
}
|
||||
}
|
||||
|
||||
// -- validate_body_size -------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn valid_body_size() {
|
||||
assert!(validate_body_size(1024).is_ok());
|
||||
assert!(validate_body_size(MAX_BODY_SIZE).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_oversized_body() {
|
||||
match validate_body_size(MAX_BODY_SIZE + 1) {
|
||||
Err(ValidationError::ParameterOutOfRange { ref name, .. }) => {
|
||||
assert_eq!(name, "body_size");
|
||||
}
|
||||
other => panic!("expected ParameterOutOfRange, got {other:?}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user