Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
299
crates/ruvector-solver/tests/test_validation.rs
Normal file
299
crates/ruvector-solver/tests/test_validation.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
//! Integration tests for input validation.
|
||||
//!
|
||||
//! Tests cover rejection of NaN and Inf values, dimension mismatches,
|
||||
//! malformed CSR structures (non-monotonic row_ptrs), and oversized inputs.
|
||||
|
||||
use ruvector_solver::error::ValidationError;
|
||||
use ruvector_solver::types::CsrMatrix;
|
||||
use ruvector_solver::validation::{
|
||||
validate_csr_matrix, validate_rhs, validate_solver_input, MAX_NODES,
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper: build a valid f32 identity matrix
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn identity_f32(n: usize) -> CsrMatrix<f32> {
|
||||
let row_ptr: Vec<usize> = (0..=n).collect();
|
||||
let col_indices: Vec<usize> = (0..n).collect();
|
||||
let values = vec![1.0f32; n];
|
||||
CsrMatrix {
|
||||
row_ptr,
|
||||
col_indices,
|
||||
values,
|
||||
rows: n,
|
||||
cols: n,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build a small valid 3x3 f32 CSR matrix for testing.
|
||||
fn valid_3x3_f32() -> CsrMatrix<f32> {
|
||||
CsrMatrix::<f32>::from_coo(
|
||||
3,
|
||||
3,
|
||||
vec![
|
||||
(0, 0, 2.0),
|
||||
(0, 1, -0.5),
|
||||
(1, 0, -0.5),
|
||||
(1, 1, 2.0),
|
||||
(1, 2, -0.5),
|
||||
(2, 1, -0.5),
|
||||
(2, 2, 2.0),
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reject NaN in matrix values
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_reject_nan_input() {
|
||||
// NaN in matrix values should be rejected.
|
||||
let mut mat = identity_f32(4);
|
||||
mat.values[2] = f32::NAN;
|
||||
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(
|
||||
matches!(err, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for NaN in matrix, got {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
// NaN in RHS should be rejected.
|
||||
let rhs = vec![1.0f32, f32::NAN, 3.0];
|
||||
let err_rhs = validate_rhs(&rhs, 3).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_rhs, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for NaN in RHS, got {:?}",
|
||||
err_rhs
|
||||
);
|
||||
|
||||
// Combined validation should also catch NaN in matrix.
|
||||
let mut mat2 = valid_3x3_f32();
|
||||
mat2.values[0] = f32::NAN;
|
||||
let err_combined = validate_solver_input(&mat2, &[1.0, 2.0, 3.0], 1e-6).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_combined, ValidationError::NonFiniteValue(_)),
|
||||
"combined validation should reject NaN matrix values"
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reject Inf values
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_reject_inf_input() {
|
||||
// Positive infinity in matrix values.
|
||||
let mut mat = identity_f32(3);
|
||||
mat.values[0] = f32::INFINITY;
|
||||
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(
|
||||
matches!(err, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for +Inf in matrix, got {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
// Negative infinity in matrix values.
|
||||
let mut mat2 = identity_f32(3);
|
||||
mat2.values[1] = f32::NEG_INFINITY;
|
||||
|
||||
let err2 = validate_csr_matrix(&mat2).unwrap_err();
|
||||
assert!(
|
||||
matches!(err2, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for -Inf in matrix, got {:?}",
|
||||
err2
|
||||
);
|
||||
|
||||
// Infinity in RHS.
|
||||
let rhs = vec![1.0f32, f32::INFINITY, 3.0];
|
||||
let err_rhs = validate_rhs(&rhs, 3).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_rhs, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for Inf in RHS, got {:?}",
|
||||
err_rhs
|
||||
);
|
||||
|
||||
// Negative infinity in RHS.
|
||||
let rhs_neg = vec![f32::NEG_INFINITY, 2.0, 3.0];
|
||||
let err_neg = validate_rhs(&rhs_neg, 3).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_neg, ValidationError::NonFiniteValue(_)),
|
||||
"expected NonFiniteValue for -Inf in RHS, got {:?}",
|
||||
err_neg
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reject dimension mismatch: rhs length != matrix rows
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_reject_dimension_mismatch() {
|
||||
// RHS too short.
|
||||
let err_short = validate_rhs(&[1.0f32, 2.0], 3).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_short, ValidationError::DimensionMismatch(_)),
|
||||
"expected DimensionMismatch for rhs length < expected, got {:?}",
|
||||
err_short
|
||||
);
|
||||
|
||||
// RHS too long.
|
||||
let err_long = validate_rhs(&[1.0f32, 2.0, 3.0, 4.0], 3).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_long, ValidationError::DimensionMismatch(_)),
|
||||
"expected DimensionMismatch for rhs length > expected, got {:?}",
|
||||
err_long
|
||||
);
|
||||
|
||||
// Combined solver input validation: rhs doesn't match matrix rows.
|
||||
let mat = valid_3x3_f32();
|
||||
let rhs_wrong = vec![1.0f32, 2.0]; // length 2, but matrix is 3x3
|
||||
let err_combined = validate_solver_input(&mat, &rhs_wrong, 1e-6).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_combined, ValidationError::DimensionMismatch(_)),
|
||||
"combined validation should reject dimension mismatch, got {:?}",
|
||||
err_combined
|
||||
);
|
||||
|
||||
// Non-square matrix should be rejected by validate_solver_input.
|
||||
let non_square = CsrMatrix::<f32> {
|
||||
row_ptr: vec![0, 0, 0],
|
||||
col_indices: vec![],
|
||||
values: vec![],
|
||||
rows: 2,
|
||||
cols: 3,
|
||||
};
|
||||
let rhs_ns = vec![1.0f32, 2.0];
|
||||
let err_ns = validate_solver_input(&non_square, &rhs_ns, 1e-6).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_ns, ValidationError::DimensionMismatch(_)),
|
||||
"non-square matrix should be rejected, got {:?}",
|
||||
err_ns
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reject invalid CSR: row_ptrs not monotonic
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_reject_invalid_csr() {
|
||||
// Break monotonicity of row_ptr.
|
||||
let mut mat = identity_f32(4);
|
||||
// row_ptr is [0, 1, 2, 3, 4]. Set row_ptr[2] = 0 to break monotonicity.
|
||||
mat.row_ptr[2] = 0;
|
||||
|
||||
let err = validate_csr_matrix(&mat).unwrap_err();
|
||||
assert!(
|
||||
matches!(err, ValidationError::NonMonotonicRowPtrs { .. }),
|
||||
"expected NonMonotonicRowPtrs, got {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
// Verify the position is reported.
|
||||
if let ValidationError::NonMonotonicRowPtrs { position } = err {
|
||||
assert_eq!(
|
||||
position, 2,
|
||||
"monotonicity violation should be at position 2"
|
||||
);
|
||||
}
|
||||
|
||||
// row_ptr[0] != 0 should also be rejected.
|
||||
let bad_start = CsrMatrix::<f32> {
|
||||
row_ptr: vec![1, 2],
|
||||
col_indices: vec![0],
|
||||
values: vec![1.0],
|
||||
rows: 1,
|
||||
cols: 1,
|
||||
};
|
||||
let err_start = validate_csr_matrix(&bad_start).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_start, ValidationError::DimensionMismatch(_)),
|
||||
"row_ptr[0] != 0 should be rejected, got {:?}",
|
||||
err_start
|
||||
);
|
||||
|
||||
// row_ptr length != rows + 1 should be rejected.
|
||||
let bad_len = CsrMatrix::<f32> {
|
||||
row_ptr: vec![0, 1], // length 2, but rows = 3 so expect length 4
|
||||
col_indices: vec![0],
|
||||
values: vec![1.0],
|
||||
rows: 3,
|
||||
cols: 3,
|
||||
};
|
||||
let err_len = validate_csr_matrix(&bad_len).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_len, ValidationError::DimensionMismatch(_)),
|
||||
"wrong row_ptr length should be rejected, got {:?}",
|
||||
err_len
|
||||
);
|
||||
|
||||
// Column index out of bounds.
|
||||
let mut bad_col = identity_f32(3);
|
||||
bad_col.col_indices[1] = 99; // out of bounds
|
||||
let err_col = validate_csr_matrix(&bad_col).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_col, ValidationError::IndexOutOfBounds { .. }),
|
||||
"column index out of bounds should be rejected, got {:?}",
|
||||
err_col
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Reject oversized input: exceeds MAX_NODES
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[test]
|
||||
fn test_reject_oversized_input() {
|
||||
// Matrix with rows > MAX_NODES.
|
||||
let oversized = CsrMatrix::<f32> {
|
||||
row_ptr: vec![0; MAX_NODES + 2], // length = (MAX_NODES + 1) + 1
|
||||
col_indices: vec![],
|
||||
values: vec![],
|
||||
rows: MAX_NODES + 1,
|
||||
cols: 1,
|
||||
};
|
||||
|
||||
let err = validate_csr_matrix(&oversized).unwrap_err();
|
||||
assert!(
|
||||
matches!(err, ValidationError::MatrixTooLarge { .. }),
|
||||
"expected MatrixTooLarge for rows > MAX_NODES, got {:?}",
|
||||
err
|
||||
);
|
||||
|
||||
// Verify the reported dimensions.
|
||||
if let ValidationError::MatrixTooLarge {
|
||||
rows,
|
||||
cols,
|
||||
max_dim,
|
||||
} = err
|
||||
{
|
||||
assert_eq!(rows, MAX_NODES + 1);
|
||||
assert_eq!(cols, 1);
|
||||
assert_eq!(max_dim, MAX_NODES);
|
||||
}
|
||||
|
||||
// Matrix with cols > MAX_NODES.
|
||||
let oversized_cols = CsrMatrix::<f32> {
|
||||
row_ptr: vec![0, 0],
|
||||
col_indices: vec![],
|
||||
values: vec![],
|
||||
rows: 1,
|
||||
cols: MAX_NODES + 1,
|
||||
};
|
||||
|
||||
let err_cols = validate_csr_matrix(&oversized_cols).unwrap_err();
|
||||
assert!(
|
||||
matches!(err_cols, ValidationError::MatrixTooLarge { .. }),
|
||||
"expected MatrixTooLarge for cols > MAX_NODES, got {:?}",
|
||||
err_cols
|
||||
);
|
||||
|
||||
// A matrix at exactly MAX_NODES should be accepted (boundary check).
|
||||
// We cannot allocate MAX_NODES entries in a test, but verify the logic:
|
||||
// MAX_NODES rows with 0 nnz should be valid structurally.
|
||||
// (Skipping actual allocation of 10M-entry row_ptr for test speed.)
|
||||
}
|
||||
Reference in New Issue
Block a user