// git-subtree-dir: vendor/ruvector
// git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
//
// Accuracy validation tests
//
// Tests OCR accuracy on an Im2latex-100k-style subset (currently synthetic; see
// `load_im2latex_test_subset`) and calculates CER, WER, and BLEU.
|
|
|
|
use super::*;
|
|
use tokio;
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_simple_expressions() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
let test_cases = vec![
|
|
("x + 1", "x + 1"),
|
|
("2x - 3", "2x - 3"),
|
|
("a = b", "a = b"),
|
|
("f(x)", "f(x)"),
|
|
("y^2", "y^2"),
|
|
];
|
|
|
|
let mut total_cer = 0.0;
|
|
let mut correct = 0;
|
|
|
|
for (equation, expected) in test_cases.iter() {
|
|
let image = images::generate_simple_equation(equation);
|
|
let path = format!("/tmp/accuracy_simple_{}.png", equation.replace(' ', "_"));
|
|
image.save(&path).unwrap();
|
|
|
|
let result = test_server
|
|
.process_image(&path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
let cer = metrics::calculate_cer(expected, &result.latex);
|
|
total_cer += cer;
|
|
|
|
if latex::normalize(&result.latex) == latex::normalize(expected) {
|
|
correct += 1;
|
|
}
|
|
|
|
println!(
|
|
"Equation: {} | CER: {:.4} | Got: {}",
|
|
equation, cer, result.latex
|
|
);
|
|
}
|
|
|
|
let avg_cer = total_cer / test_cases.len() as f64;
|
|
let accuracy = correct as f64 / test_cases.len() as f64;
|
|
|
|
println!(
|
|
"Simple expressions - Avg CER: {:.4}, Accuracy: {:.2}%",
|
|
avg_cer,
|
|
accuracy * 100.0
|
|
);
|
|
|
|
assert!(avg_cer < 0.05, "Average CER too high: {:.4}", avg_cer);
|
|
assert!(
|
|
accuracy > 0.90,
|
|
"Accuracy too low: {:.2}%",
|
|
accuracy * 100.0
|
|
);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_im2latex_subset() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
// Load Im2latex-100k test subset (sample)
|
|
let test_cases = load_im2latex_test_subset(50); // Test 50 samples
|
|
|
|
let mut cer_sum = 0.0;
|
|
let mut wer_sum = 0.0;
|
|
let mut bleu_sum = 0.0;
|
|
let mut exact_matches = 0;
|
|
|
|
for (i, case) in test_cases.iter().enumerate() {
|
|
// Generate or load image
|
|
let image_path = case.image_path.clone();
|
|
|
|
let result = test_server
|
|
.process_image(&image_path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
// Calculate metrics
|
|
let cer = metrics::calculate_cer(&case.ground_truth, &result.latex);
|
|
let wer = metrics::calculate_wer(&case.ground_truth, &result.latex);
|
|
let bleu = metrics::calculate_bleu(&case.ground_truth, &result.latex, 4);
|
|
|
|
cer_sum += cer;
|
|
wer_sum += wer;
|
|
bleu_sum += bleu;
|
|
|
|
if latex::normalize(&result.latex) == latex::normalize(&case.ground_truth) {
|
|
exact_matches += 1;
|
|
}
|
|
|
|
if i % 10 == 0 {
|
|
println!("Processed {}/{} samples", i + 1, test_cases.len());
|
|
}
|
|
}
|
|
|
|
let count = test_cases.len() as f64;
|
|
let avg_cer = cer_sum / count;
|
|
let avg_wer = wer_sum / count;
|
|
let avg_bleu = bleu_sum / count;
|
|
let exact_match_rate = exact_matches as f64 / count;
|
|
|
|
println!("\nIm2latex subset results:");
|
|
println!(" Average CER: {:.4}", avg_cer);
|
|
println!(" Average WER: {:.4}", avg_wer);
|
|
println!(" Average BLEU: {:.2}", avg_bleu);
|
|
println!(" Exact match rate: {:.2}%", exact_match_rate * 100.0);
|
|
|
|
// Assert quality thresholds
|
|
assert!(avg_cer < 0.03, "CER too high: {:.4}", avg_cer);
|
|
assert!(avg_bleu > 80.0, "BLEU too low: {:.2}", avg_bleu);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_fractions() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
let test_cases = vec![
|
|
((1, 2), r"\frac{1}{2}"),
|
|
((3, 4), r"\frac{3}{4}"),
|
|
((5, 6), r"\frac{5}{6}"),
|
|
((10, 3), r"\frac{10}{3}"),
|
|
];
|
|
|
|
let mut correct = 0;
|
|
|
|
for ((num, den), expected) in test_cases.iter() {
|
|
let image = images::generate_fraction(*num, *den);
|
|
let path = format!("/tmp/frac_{}_{}.png", num, den);
|
|
image.save(&path).unwrap();
|
|
|
|
let result = test_server
|
|
.process_image(&path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
if latex::expressions_match(&result.latex, expected) {
|
|
correct += 1;
|
|
} else {
|
|
println!(
|
|
"Fraction {}/{} - Expected: {}, Got: {}",
|
|
num, den, expected, result.latex
|
|
);
|
|
}
|
|
}
|
|
|
|
let accuracy = correct as f64 / test_cases.len() as f64;
|
|
println!("Fraction accuracy: {:.2}%", accuracy * 100.0);
|
|
|
|
assert!(
|
|
accuracy >= 0.85,
|
|
"Fraction accuracy too low: {:.2}%",
|
|
accuracy * 100.0
|
|
);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_special_symbols() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
let test_cases = vec![
|
|
(r"\alpha", r"\alpha"),
|
|
(r"\beta", r"\beta"),
|
|
(r"\sum", r"\sum"),
|
|
(r"\int", r"\int"),
|
|
(r"\pi", r"\pi"),
|
|
(r"\infty", r"\infty"),
|
|
];
|
|
|
|
let mut correct = 0;
|
|
|
|
for (symbol, expected) in test_cases.iter() {
|
|
let image = images::generate_symbol(symbol);
|
|
let path = format!("/tmp/symbol_{}.png", symbol.replace('\\', ""));
|
|
image.save(&path).unwrap();
|
|
|
|
let result = test_server
|
|
.process_image(&path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
if result.latex.contains(expected) {
|
|
correct += 1;
|
|
} else {
|
|
println!(
|
|
"Symbol {} - Expected to contain: {}, Got: {}",
|
|
symbol, expected, result.latex
|
|
);
|
|
}
|
|
}
|
|
|
|
let accuracy = correct as f64 / test_cases.len() as f64;
|
|
println!("Special symbol accuracy: {:.2}%", accuracy * 100.0);
|
|
|
|
assert!(
|
|
accuracy >= 0.80,
|
|
"Symbol accuracy too low: {:.2}%",
|
|
accuracy * 100.0
|
|
);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_regression_detection() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
// Load baseline results
|
|
let baseline = load_baseline_results();
|
|
|
|
// Run same test cases
|
|
let test_cases = load_regression_test_cases();
|
|
|
|
let mut regressions = Vec::new();
|
|
|
|
for case in test_cases.iter() {
|
|
let result = test_server
|
|
.process_image(&case.image_path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
// Compare with baseline
|
|
if let Some(baseline_result) = baseline.get(&case.id) {
|
|
let current_cer = metrics::calculate_cer(&case.ground_truth, &result.latex);
|
|
let baseline_cer = baseline_result.cer;
|
|
|
|
// Check for regression (10% threshold)
|
|
if current_cer > baseline_cer * 1.10 {
|
|
regressions.push((
|
|
case.id.clone(),
|
|
baseline_cer,
|
|
current_cer,
|
|
baseline_result.latex.clone(),
|
|
result.latex.clone(),
|
|
));
|
|
}
|
|
}
|
|
}
|
|
|
|
if !regressions.is_empty() {
|
|
println!("Regressions detected:");
|
|
for (id, baseline_cer, current_cer, baseline_latex, current_latex) in ®ressions {
|
|
println!(" {} - CER: {:.4} -> {:.4}", id, baseline_cer, current_cer);
|
|
println!(" Baseline: {}", baseline_latex);
|
|
println!(" Current: {}", current_latex);
|
|
}
|
|
}
|
|
|
|
assert!(
|
|
regressions.is_empty(),
|
|
"Found {} regressions",
|
|
regressions.len()
|
|
);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_accuracy_confidence_calibration() {
|
|
let test_server = TestServer::start()
|
|
.await
|
|
.expect("Failed to start test server");
|
|
|
|
let test_cases = load_calibration_test_cases();
|
|
|
|
let mut high_conf_correct = 0;
|
|
let mut high_conf_total = 0;
|
|
let mut low_conf_correct = 0;
|
|
let mut low_conf_total = 0;
|
|
|
|
for case in test_cases.iter() {
|
|
let result = test_server
|
|
.process_image(&case.image_path, OutputFormat::LaTeX)
|
|
.await
|
|
.expect("Processing failed");
|
|
|
|
let is_correct = latex::normalize(&result.latex) == latex::normalize(&case.ground_truth);
|
|
|
|
if result.confidence > 0.9 {
|
|
high_conf_total += 1;
|
|
if is_correct {
|
|
high_conf_correct += 1;
|
|
}
|
|
} else if result.confidence < 0.7 {
|
|
low_conf_total += 1;
|
|
if is_correct {
|
|
low_conf_correct += 1;
|
|
}
|
|
}
|
|
}
|
|
|
|
let high_conf_accuracy = if high_conf_total > 0 {
|
|
high_conf_correct as f64 / high_conf_total as f64
|
|
} else {
|
|
1.0
|
|
};
|
|
|
|
let low_conf_accuracy = if low_conf_total > 0 {
|
|
low_conf_correct as f64 / low_conf_total as f64
|
|
} else {
|
|
0.0
|
|
};
|
|
|
|
println!("Confidence calibration:");
|
|
println!(
|
|
" High confidence (>0.9): {:.2}% accuracy ({}/{})",
|
|
high_conf_accuracy * 100.0,
|
|
high_conf_correct,
|
|
high_conf_total
|
|
);
|
|
println!(
|
|
" Low confidence (<0.7): {:.2}% accuracy ({}/{})",
|
|
low_conf_accuracy * 100.0,
|
|
low_conf_correct,
|
|
low_conf_total
|
|
);
|
|
|
|
// High confidence should correlate with high accuracy
|
|
assert!(
|
|
high_conf_accuracy > 0.95,
|
|
"High confidence predictions should be very accurate"
|
|
);
|
|
|
|
test_server.shutdown().await;
|
|
}
|
|
|
|
// Helper functions and types
|
|
|
|
/// One OCR evaluation sample: an image on disk and the LaTeX expected from it.
#[derive(Clone, Debug)]
struct TestCase {
    /// Sample identifier; the regression test also uses it as the baseline key.
    id: String,
    /// Path of the rendered input image handed to `process_image`.
    image_path: String,
    /// Expected LaTeX for the image.
    ground_truth: String,
}
|
|
|
|
/// A previously recorded result for one regression case.
#[derive(Clone, Debug)]
struct BaselineResult {
    /// LaTeX produced when the baseline was recorded; printed when a regression is found.
    latex: String,
    /// CER of that recorded output, compared against the current run's CER.
    cer: f64,
}
|
|
|
|
fn load_im2latex_test_subset(count: usize) -> Vec<TestCase> {
|
|
// Load or generate Im2latex test subset
|
|
// For now, generate synthetic test cases
|
|
(0..count)
|
|
.map(|i| {
|
|
let eq = match i % 5 {
|
|
0 => format!("x^{}", i),
|
|
1 => format!("a + {}", i),
|
|
2 => format!(r"\frac{{{}}}{{{}}}", i, i + 1),
|
|
3 => format!("{}x + {}", i, i * 2),
|
|
_ => format!("y = {}x", i),
|
|
};
|
|
|
|
let image = images::generate_simple_equation(&eq);
|
|
let path = format!("/tmp/im2latex_{}.png", i);
|
|
image.save(&path).unwrap();
|
|
|
|
TestCase {
|
|
id: format!("im2latex_{}", i),
|
|
image_path: path,
|
|
ground_truth: eq,
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn load_regression_test_cases() -> Vec<TestCase> {
|
|
// Load regression test cases from file or generate
|
|
vec![
|
|
TestCase {
|
|
id: "reg_001".to_string(),
|
|
image_path: "/tmp/reg_001.png".to_string(),
|
|
ground_truth: "x + y".to_string(),
|
|
},
|
|
// Add more test cases...
|
|
]
|
|
}
|
|
|
|
fn load_baseline_results() -> std::collections::HashMap<String, BaselineResult> {
|
|
// Load baseline results from file
|
|
let mut baseline = std::collections::HashMap::new();
|
|
|
|
baseline.insert(
|
|
"reg_001".to_string(),
|
|
BaselineResult {
|
|
latex: "x + y".to_string(),
|
|
cer: 0.0,
|
|
},
|
|
);
|
|
|
|
baseline
|
|
}
|
|
|
|
fn load_calibration_test_cases() -> Vec<TestCase> {
|
|
// Generate test cases with varying difficulty for confidence calibration
|
|
let mut cases = Vec::new();
|
|
|
|
// Easy cases
|
|
for i in 0..10 {
|
|
let eq = format!("x + {}", i);
|
|
let image = images::generate_simple_equation(&eq);
|
|
let path = format!("/tmp/calib_easy_{}.png", i);
|
|
image.save(&path).unwrap();
|
|
|
|
cases.push(TestCase {
|
|
id: format!("calib_easy_{}", i),
|
|
image_path: path,
|
|
ground_truth: eq,
|
|
});
|
|
}
|
|
|
|
// Hard cases (noisy)
|
|
for i in 0..10 {
|
|
let eq = format!("y^{}", i);
|
|
let mut image = images::generate_simple_equation(&eq);
|
|
images::add_noise(&mut image, 0.2);
|
|
let path = format!("/tmp/calib_hard_{}.png", i);
|
|
image.save(&path).unwrap();
|
|
|
|
cases.push(TestCase {
|
|
id: format!("calib_hard_{}", i),
|
|
image_path: path,
|
|
ground_truth: eq,
|
|
});
|
|
}
|
|
|
|
cases
|
|
}
|