Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,384 @@
//! Confidence Scoring Module
//!
//! This module provides confidence scoring and calibration for OCR results.
//! It includes per-character confidence calculation and aggregation methods.
use super::Result;
use std::collections::HashMap;
use tracing::debug;
/// Calculate confidence score for a single character prediction
///
/// # Arguments
/// * `logits` - Raw logits from the model for this character position
///
/// # Returns
/// Confidence score between 0.0 and 1.0 (the softmax probability of the
/// argmax class); 0.0 for an empty slice.
pub fn calculate_confidence(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    // Softmax probability of the argmax class, using the max-shift trick for
    // numerical stability. Since the largest logit shifts to exp(0) == 1, the
    // maximum probability is simply 1 / sum(exp(x - max)); the previous
    // implementation recomputed every shifted exponent in a second pass just
    // to find a value that is always 1/exp_sum.
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
    (1.0 / exp_sum).clamp(0.0, 1.0)
}
/// Aggregate multiple confidence scores into a single score
///
/// # Arguments
/// * `confidences` - Individual confidence scores, expected in [0.0, 1.0]
///
/// # Returns
/// Aggregated confidence score using the geometric mean (more conservative
/// than the arithmetic mean); 0.0 for an empty slice or if any score is 0.
pub fn aggregate_confidence(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    // Geometric mean computed in log space: exp(mean(ln(c))).
    // A direct f32 product of many scores in (0, 1) underflows to 0.0 for
    // long sequences (e.g. 0.9^1000 is below f32's subnormal range), which
    // would wrongly report zero confidence for long text. Summing logs is
    // mathematically equivalent and does not underflow. Scores are floored
    // at 0.0 so a (nonsensical) negative input yields 0 rather than NaN.
    let n = confidences.len() as f32;
    let log_sum: f32 = confidences.iter().map(|&c| c.max(0.0).ln()).sum();
    (log_sum / n).exp().clamp(0.0, 1.0)
}
/// Alternative aggregation using arithmetic mean
///
/// Returns 0.0 for an empty slice; otherwise the average of the scores,
/// clamped to [0.0, 1.0].
pub fn aggregate_confidence_mean(confidences: &[f32]) -> f32 {
    match confidences.len() {
        0 => 0.0,
        count => {
            let total: f32 = confidences.iter().copied().sum();
            (total / count as f32).clamp(0.0, 1.0)
        }
    }
}
/// Alternative aggregation using minimum (most conservative)
///
/// Returns 0.0 for an empty slice, consistent with the other aggregators.
/// (The previous fold-from-1.0 implementation returned 1.0 — full
/// confidence — when given no scores at all, which is the wrong default.)
pub fn aggregate_confidence_min(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    confidences
        .iter()
        .fold(1.0f32, |a, &b| a.min(b))
        .clamp(0.0, 1.0)
}
/// Alternative aggregation using harmonic mean
///
/// The harmonic mean is dominated by the smallest scores, making it a
/// conservative aggregate. Returns 0.0 for an empty slice.
pub fn aggregate_confidence_harmonic(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    // Floor each score at 0.001 so a zero confidence cannot divide by zero.
    let recip_sum = confidences
        .iter()
        .fold(0.0f32, |acc, &c| acc + 1.0 / c.max(0.001));
    let count = confidences.len() as f32;
    (count / recip_sum).clamp(0.0, 1.0)
}
/// Confidence calibrator using isotonic regression
///
/// This calibrator learns a mapping from raw confidence scores to calibrated
/// probabilities using historical data. Scores are discretized into integer
/// percent bins (0-100) and each bin maps to the empirical accuracy observed
/// for that bin during training.
pub struct ConfidenceCalibrator {
/// Calibration mapping: raw_score -> calibrated_score
/// Keys are raw scores binned to integer percent; values are the mean
/// accuracy seen for that bin, made non-decreasing by `enforce_monotonicity`.
calibration_map: HashMap<u8, f32>, // Use u8 for binned scores (0-100)
/// Whether the calibrator has been trained
/// (until `train` succeeds, `calibrate` passes raw scores through)
is_trained: bool,
}
impl ConfidenceCalibrator {
/// Create a new, untrained calibrator
///
/// Until [`train`](Self::train) succeeds, `calibrate` returns the raw
/// score unchanged (clamped to [0, 1]).
pub fn new() -> Self {
Self {
calibration_map: HashMap::new(),
is_trained: false,
}
}
/// Train the calibrator on labeled data
///
/// # Arguments
/// * `predictions` - Raw confidence scores from the model, expected in [0, 1]
/// * `ground_truth` - Binary labels (1.0 if correct, 0.0 if incorrect)
///
/// # Errors
/// Returns `OcrError::InvalidConfig` when the slices differ in length or
/// when they are empty.
pub fn train(&mut self, predictions: &[f32], ground_truth: &[f32]) -> Result<()> {
debug!(
"Training confidence calibrator on {} samples",
predictions.len()
);
if predictions.len() != ground_truth.len() {
return Err(super::OcrError::InvalidConfig(
"Predictions and ground truth must have same length".to_string(),
));
}
if predictions.is_empty() {
return Err(super::OcrError::InvalidConfig(
"Cannot train on empty data".to_string(),
));
}
// Bin the scores (0.0-1.0 -> 0-100); each bin collects the labels of the
// samples whose raw score fell in that percent bucket.
let mut bins: HashMap<u8, Vec<f32>> = HashMap::new();
for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
let bin = (pred * 100.0).clamp(0.0, 100.0) as u8;
bins.entry(bin).or_insert_with(Vec::new).push(truth);
}
// Calculate mean accuracy for each bin: the empirical probability that a
// prediction in this bucket was correct.
self.calibration_map.clear();
for (bin, truths) in bins {
let mean_accuracy = truths.iter().sum::<f32>() / truths.len() as f32;
self.calibration_map.insert(bin, mean_accuracy);
}
// Perform isotonic regression (simplified version) so higher raw scores
// never map to lower calibrated scores.
self.enforce_monotonicity();
self.is_trained = true;
debug!(
"Calibrator trained with {} bins",
self.calibration_map.len()
);
Ok(())
}
/// Enforce monotonicity constraint (isotonic regression)
///
/// Walks the bins in ascending order and raises any bin whose value dips
/// below its predecessor. This is a one-pass "running maximum" rather than
/// full pool-adjacent-violators, so values are pulled up, never averaged.
fn enforce_monotonicity(&mut self) {
let mut sorted_bins: Vec<_> = self.calibration_map.iter().collect();
sorted_bins.sort_by_key(|(bin, _)| *bin);
// Simple isotonic regression: ensure calibrated scores are non-decreasing
let mut adjusted = HashMap::new();
let mut prev_value = 0.0;
for (&bin, &value) in sorted_bins {
let adjusted_value = value.max(prev_value);
adjusted.insert(bin, adjusted_value);
prev_value = adjusted_value;
}
self.calibration_map = adjusted;
}
/// Calibrate a raw confidence score
///
/// Returns the bin's learned accuracy, linearly interpolating between the
/// nearest trained bins when the exact bin was never seen. Untrained
/// calibrators return the raw score clamped to [0, 1].
pub fn calibrate(&self, raw_score: f32) -> f32 {
if !self.is_trained {
// If not trained, return raw score
return raw_score.clamp(0.0, 1.0);
}
let bin = (raw_score * 100.0).clamp(0.0, 100.0) as u8;
// Look up calibrated score, or interpolate
if let Some(&calibrated) = self.calibration_map.get(&bin) {
return calibrated;
}
// Interpolate between nearest bins
self.interpolate(bin)
}
/// Interpolate calibrated score for a bin without direct mapping
///
/// Finds the closest trained bin below and above `target_bin` and blends
/// their values linearly; falls back to the one-sided neighbor at the
/// edges, or to `target_bin / 100` if the map is empty.
fn interpolate(&self, target_bin: u8) -> f32 {
let mut lower = None;
let mut upper = None;
for &bin in self.calibration_map.keys() {
if bin < target_bin {
lower = Some(lower.map_or(bin, |l: u8| l.max(bin)));
} else if bin > target_bin {
upper = Some(upper.map_or(bin, |u: u8| u.min(bin)));
}
}
match (lower, upper) {
(Some(l), Some(u)) => {
let l_val = self.calibration_map[&l];
let u_val = self.calibration_map[&u];
// u > l is guaranteed here, so the u8 subtractions cannot underflow.
let alpha = (target_bin - l) as f32 / (u - l) as f32;
l_val + alpha * (u_val - l_val)
}
(Some(l), None) => self.calibration_map[&l],
(None, Some(u)) => self.calibration_map[&u],
(None, None) => target_bin as f32 / 100.0, // Fallback (empty map; unreachable once trained)
}
}
/// Check if the calibrator is trained
pub fn is_trained(&self) -> bool {
self.is_trained
}
/// Reset the calibrator to its untrained state, discarding all bins
pub fn reset(&mut self) {
self.calibration_map.clear();
self.is_trained = false;
}
}
impl Default for ConfidenceCalibrator {
/// Equivalent to [`ConfidenceCalibrator::new`]: an empty, untrained calibrator.
fn default() -> Self {
Self::new()
}
}
/// Calculate Expected Calibration Error (ECE)
///
/// Measures the difference between predicted confidence and actual accuracy:
/// the bin-size-weighted mean of |avg_confidence - avg_accuracy| over
/// `n_bins` equal-width confidence bins.
///
/// Returns 0.0 on degenerate input (mismatched lengths, empty data, or zero
/// bins).
pub fn calculate_ece(predictions: &[f32], ground_truth: &[f32], n_bins: usize) -> f32 {
    // Guard against n_bins == 0 as well: the old code computed `n_bins - 1`
    // on usize, which underflows and panics (debug) or indexes wildly (release).
    if predictions.len() != ground_truth.len() || predictions.is_empty() || n_bins == 0 {
        return 0.0;
    }
    let mut bins: Vec<Vec<(f32, f32)>> = vec![Vec::new(); n_bins];
    // Assign each (prediction, label) pair to its confidence bin; a
    // prediction of exactly 1.0 lands in the last bin via the `min`.
    for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
        let bin_idx = ((pred * n_bins as f32) as usize).min(n_bins - 1);
        bins[bin_idx].push((pred, truth));
    }
    // ECE = sum over bins of (bin weight) * |mean confidence - mean accuracy|
    let mut ece = 0.0;
    let total = predictions.len() as f32;
    for bin in bins {
        if bin.is_empty() {
            continue;
        }
        let bin_size = bin.len() as f32;
        let avg_confidence: f32 = bin.iter().map(|(p, _)| p).sum::<f32>() / bin_size;
        let avg_accuracy: f32 = bin.iter().map(|(_, t)| t).sum::<f32>() / bin_size;
        ece += (bin_size / total) * (avg_confidence - avg_accuracy).abs();
    }
    ece
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared tolerance for f32 comparisons. Exact `assert_eq!` on computed
    /// f32 means is fragile: (0.8 + 0.9 + 0.7) / 3.0 in f32 rounds to
    /// ~0.80000007, which is NOT bit-equal to 0.8f32.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_calculate_confidence() {
        let logits = vec![1.0, 5.0, 2.0, 1.0];
        let conf = calculate_confidence(&logits);
        assert!(conf > 0.5);
        assert!(conf <= 1.0);
    }

    #[test]
    fn test_calculate_confidence_empty() {
        let logits: Vec<f32> = vec![];
        let conf = calculate_confidence(&logits);
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_aggregate_confidence() {
        let confidences = vec![0.9, 0.8, 0.95, 0.85];
        let agg = aggregate_confidence(&confidences);
        assert!(agg > 0.0 && agg <= 1.0);
        assert!(agg < 0.9); // Geometric mean should be less than max
    }

    #[test]
    fn test_aggregate_confidence_mean() {
        let confidences = vec![0.8, 0.9, 0.7];
        let mean = aggregate_confidence_mean(&confidences);
        // (0.8 + 0.9 + 0.7) / 3 — compared approximately, not bit-exactly.
        assert!((mean - 0.8).abs() < EPS);
    }

    #[test]
    fn test_aggregate_confidence_min() {
        let confidences = vec![0.9, 0.7, 0.95];
        let min = aggregate_confidence_min(&confidences);
        // The minimum element passes through untouched, so this stays exact.
        assert_eq!(min, 0.7);
    }

    #[test]
    fn test_aggregate_confidence_harmonic() {
        let confidences = vec![0.5, 0.5];
        let harmonic = aggregate_confidence_harmonic(&confidences);
        assert!((harmonic - 0.5).abs() < EPS);
    }

    #[test]
    fn test_calibrator_training() {
        let mut calibrator = ConfidenceCalibrator::new();
        assert!(!calibrator.is_trained());
        let predictions = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0, 0.0];
        let result = calibrator.train(&predictions, &ground_truth);
        assert!(result.is_ok());
        assert!(calibrator.is_trained());
    }

    #[test]
    fn test_calibrator_calibrate() {
        let mut calibrator = ConfidenceCalibrator::new();
        // Before training, should return raw score
        assert_eq!(calibrator.calibrate(0.8), 0.8);
        // Train with some data
        let predictions = vec![0.9, 0.9, 0.8, 0.8, 0.7, 0.7];
        let ground_truth = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        // After training, should return calibrated score
        let calibrated = calibrator.calibrate(0.85);
        assert!(calibrated >= 0.0 && calibrated <= 1.0);
    }

    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = ConfidenceCalibrator::new();
        let predictions = vec![0.9, 0.8];
        let ground_truth = vec![1.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        assert!(calibrator.is_trained());
        calibrator.reset();
        assert!(!calibrator.is_trained());
    }

    #[test]
    fn test_calculate_ece() {
        let predictions = vec![0.9, 0.7, 0.6, 0.8];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0];
        let ece = calculate_ece(&predictions, &ground_truth, 3);
        assert!(ece >= 0.0 && ece <= 1.0);
    }

    #[test]
    fn test_calibrator_monotonicity() {
        let mut calibrator = ConfidenceCalibrator::new();
        // Create data that would violate monotonicity
        let predictions = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let ground_truth = vec![0.2, 0.3, 0.2, 0.5, 0.4, 0.7, 0.8, 0.9, 1.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        // Check monotonicity
        let score1 = calibrator.calibrate(0.3);
        let score2 = calibrator.calibrate(0.5);
        let score3 = calibrator.calibrate(0.7);
        assert!(score2 >= score1, "Calibrated scores should be monotonic");
        assert!(score3 >= score2, "Calibrated scores should be monotonic");
    }
}

View File

@@ -0,0 +1,441 @@
//! Output Decoding Module
//!
//! This module provides various decoding strategies for converting
//! model output logits into text strings.
use super::{OcrError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
/// Decoder trait for converting logits to text
pub trait Decoder: Send + Sync {
    /// Decode logits to text
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String>;
    /// Decode with confidence scores per character
    ///
    /// The default implementation decodes the text and assigns a uniform
    /// confidence of 1.0 to every character.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let text = self.decode(logits)?;
        // Count Unicode scalar values, not bytes: `str::len` is the byte
        // length, which over-counts multi-byte characters and would
        // desynchronize the confidence vector from the character sequence.
        let confidences = vec![1.0; text.chars().count()];
        Ok((text, confidences))
    }
}
/// Vocabulary mapping for character recognition
#[derive(Debug, Clone)]
pub struct Vocabulary {
    /// Index to character mapping
    idx_to_char: HashMap<usize, char>,
    /// Character to index mapping
    char_to_idx: HashMap<char, usize>,
    /// Blank token index for CTC
    blank_idx: usize,
}
impl Vocabulary {
    /// Create a new vocabulary from an ordered character list and a blank
    /// token index.
    pub fn new(chars: Vec<char>, blank_idx: usize) -> Self {
        let mut idx_to_char = HashMap::new();
        let mut char_to_idx = HashMap::new();
        for (index, ch) in chars.iter().copied().enumerate() {
            idx_to_char.insert(index, ch);
            char_to_idx.insert(ch, index);
        }
        Self {
            idx_to_char,
            char_to_idx,
            blank_idx,
        }
    }
    /// Look up the character at a vocabulary index.
    pub fn get_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }
    /// Look up the vocabulary index of a character.
    pub fn get_idx(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }
    /// Index reserved for the CTC blank token.
    pub fn blank_idx(&self) -> usize {
        self.blank_idx
    }
    /// Number of real (non-blank) characters in the vocabulary.
    pub fn size(&self) -> usize {
        self.idx_to_char.len()
    }
}
impl Default for Vocabulary {
    /// Default vocabulary: `a`-`z`, then `0`-`9`, then space, with the blank
    /// token index placed one past the last character.
    fn default() -> Self {
        let chars: Vec<char> = ('a'..='z')
            .chain('0'..='9')
            .chain(std::iter::once(' '))
            .collect();
        let blank_idx = chars.len();
        Self::new(chars, blank_idx)
    }
}
/// Greedy decoder - selects the character with highest probability at each step
pub struct GreedyDecoder {
    vocabulary: Arc<Vocabulary>,
}
impl GreedyDecoder {
    /// Create a new greedy decoder
    pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
        Self { vocabulary }
    }
    /// Index of the largest value in `values`, or 0 when the slice is empty.
    /// Incomparable (NaN) pairs are treated as equal, matching `max_by`'s
    /// "last of the maxima wins" tie behavior.
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(idx, _)| idx)
    }
}
impl Decoder for GreedyDecoder {
    /// Greedy CTC decode: per-frame argmax, then blank removal and collapse
    /// of consecutive duplicate indices.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("Greedy decoding {} frames", logits.len());
        let blank = self.vocabulary.blank_idx();
        let mut text = String::new();
        let mut previous: Option<usize> = None;
        for frame in logits {
            let best = Self::argmax(frame);
            // Emit only non-blank indices that differ from the previous frame.
            if best != blank && previous != Some(best) {
                if let Some(ch) = self.vocabulary.get_char(best) {
                    text.push(ch);
                }
            }
            previous = Some(best);
        }
        Ok(text)
    }
    /// Same traversal as `decode`, but also records the frame's max softmax
    /// probability for every emitted character.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let blank = self.vocabulary.blank_idx();
        let mut text = String::new();
        let mut scores = Vec::new();
        let mut previous: Option<usize> = None;
        for frame in logits {
            let best = Self::argmax(frame);
            if best != blank && previous != Some(best) {
                if let Some(ch) = self.vocabulary.get_char(best) {
                    text.push(ch);
                    scores.push(softmax_max(frame));
                }
            }
            previous = Some(best);
        }
        Ok((text, scores))
    }
}
/// Beam search decoder - maintains top-k hypotheses for better accuracy
pub struct BeamSearchDecoder {
    vocabulary: Arc<Vocabulary>,
    beam_width: usize,
}
impl BeamSearchDecoder {
    /// Create a new beam search decoder
    ///
    /// A `beam_width` of 0 is bumped to 1 so the search always keeps at
    /// least one hypothesis.
    pub fn new(vocabulary: Arc<Vocabulary>, beam_width: usize) -> Self {
        let beam_width = beam_width.max(1);
        Self {
            vocabulary,
            beam_width,
        }
    }
    /// Get beam width
    pub fn beam_width(&self) -> usize {
        self.beam_width
    }
}
impl Decoder for BeamSearchDecoder {
    /// Beam search over per-frame logits with CTC-style blank/repeat handling.
    ///
    /// At every frame the top `beam_width` logits extend each surviving beam;
    /// beams are then re-ranked by accumulated (raw-logit) score and pruned
    /// back to `beam_width`. Beams with identical transcripts are NOT merged,
    /// so this is a simplified search rather than full CTC beam search.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!(
            "Beam search decoding {} frames (beam_width: {})",
            logits.len(),
            self.beam_width
        );
        if logits.is_empty() {
            return Ok(String::new());
        }
        // Beams: (text, accumulated score, last emitted index).
        let mut beams: Vec<(String, f32, Option<usize>)> = vec![(String::new(), 0.0, None)];
        for frame_logits in logits {
            // Rank this frame's logits once. The ranking is identical for
            // every beam, so hoisting it out of the beam loop removes a
            // redundant O(V log V) sort per beam per frame.
            let mut indexed_logits: Vec<(usize, f32)> = frame_logits
                .iter()
                .enumerate()
                .map(|(i, &v)| (i, v))
                .collect();
            // `unwrap_or(Equal)` instead of `unwrap()`: a NaN logit must not
            // panic the decoder.
            indexed_logits
                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            let top_k = &indexed_logits[..self.beam_width.min(indexed_logits.len())];
            let mut new_beams = Vec::new();
            for (text, score, last_idx) in &beams {
                for &(idx, logit) in top_k {
                    let new_score = score + logit;
                    // Blank token or CTC repeat: transcript is unchanged.
                    if idx == self.vocabulary.blank_idx() || Some(idx) == *last_idx {
                        new_beams.push((text.clone(), new_score, Some(idx)));
                        continue;
                    }
                    // Extend the beam with the new character (indices without
                    // a vocabulary entry are dropped entirely).
                    if let Some(ch) = self.vocabulary.get_char(idx) {
                        let mut new_text = text.clone();
                        new_text.push(ch);
                        new_beams.push((new_text, new_score, Some(idx)));
                    }
                }
            }
            // Keep the top beam_width beams (stable sort preserves insertion
            // order among equal scores).
            new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
            new_beams.truncate(self.beam_width);
            beams = new_beams;
        }
        // Return the best beam's transcript.
        Ok(beams
            .first()
            .map(|(text, _, _)| text.clone())
            .unwrap_or_default())
    }
}
/// CTC (Connectionist Temporal Classification) decoder
pub struct CTCDecoder {
    vocabulary: Arc<Vocabulary>,
}
impl CTCDecoder {
    /// Create a new CTC decoder
    pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
        Self { vocabulary }
    }
    /// Collapse repeated characters and remove blanks
    ///
    /// Standard CTC post-processing: a blank separates repeats, so the frame
    /// sequence "a a _ a" collapses to "a a" while "a a a" collapses to "a".
    fn collapse_repeats(&self, indices: &[usize]) -> Vec<usize> {
        let blank = self.vocabulary.blank_idx();
        let mut collapsed = Vec::new();
        let mut last_seen: Option<usize> = None;
        for &idx in indices {
            // Emit only non-blank indices that differ from the previous frame.
            if idx != blank && last_seen != Some(idx) {
                collapsed.push(idx);
            }
            last_seen = Some(idx);
        }
        collapsed
    }
}
impl Decoder for CTCDecoder {
    /// Best-path CTC decode: per-frame argmax, collapse repeats, drop blanks.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("CTC decoding {} frames", logits.len());
        // Get best path (greedy)
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();
        // Collapse repeats and remove blanks
        let collapsed = self.collapse_repeats(&indices);
        // Convert to text (indices without a vocabulary entry are dropped)
        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();
        Ok(text)
    }
    /// Best-path decode that also returns a per-character confidence.
    ///
    /// Walks the frames once and emits each character together with that
    /// frame's max softmax probability, so `text` and the confidence vector
    /// cannot drift apart. (The previous implementation pushed a confidence
    /// even when the argmax index had no vocabulary character, which could
    /// leave `confidences` longer than `text`; it also computed softmax for
    /// every frame up front, including frames that emit nothing.)
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let blank = self.vocabulary.blank_idx();
        let mut text = String::new();
        let mut result_confidences = Vec::new();
        let mut prev_idx = None;
        for frame in logits {
            let idx = GreedyDecoder::argmax(frame);
            if idx != blank && Some(idx) != prev_idx {
                if let Some(ch) = self.vocabulary.get_char(idx) {
                    text.push(ch);
                    result_confidences.push(softmax_max(frame));
                }
            }
            prev_idx = Some(idx);
        }
        Ok((text, result_confidences))
    }
}
/// Calculate softmax and return the maximum class probability.
///
/// Returns 0.0 for an empty slice. Uses the max-shift trick for numerical
/// stability; since the largest logit shifts to exp(0) == 1, the result is
/// simply 1 / sum(exp(x - max)). (The previous implementation scanned the
/// slice a second time to recompute the max it already had.)
fn softmax_max(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
    1.0 / exp_sum
}
#[cfg(test)]
mod tests {
    use super::*;
    fn create_test_vocabulary() -> Arc<Vocabulary> {
        Arc::new(Vocabulary::default())
    }
    #[test]
    fn test_vocabulary_default() {
        let vocab = Vocabulary::default();
        assert!(vocab.size() > 0);
        assert_eq!(vocab.get_char(0), Some('a'));
        assert_eq!(vocab.get_idx('a'), Some(0));
    }
    #[test]
    fn test_greedy_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab.clone());
        // Mock logits for "hi"
        let h_idx = vocab.get_idx('h').unwrap();
        let i_idx = vocab.get_idx('i').unwrap();
        let blank = vocab.blank_idx();
        let mut logits = vec![
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
        ];
        logits[0][h_idx] = 10.0;
        logits[1][blank] = 10.0;
        logits[2][i_idx] = 10.0;
        let result = decoder.decode(&logits).unwrap();
        assert_eq!(result, "hi");
    }
    #[test]
    fn test_beam_search_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = BeamSearchDecoder::new(vocab.clone(), 3);
        assert_eq!(decoder.beam_width(), 3);
        let logits = vec![vec![0.0; vocab.size() + 1]; 5];
        let result = decoder.decode(&logits);
        assert!(result.is_ok());
    }
    #[test]
    fn test_ctc_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = CTCDecoder::new(vocab.clone());
        // Test collapse repeats
        let a_idx = vocab.get_idx('a').unwrap();
        let b_idx = vocab.get_idx('b').unwrap();
        let blank = vocab.blank_idx();
        let indices = vec![a_idx, a_idx, blank, b_idx, b_idx, b_idx];
        let collapsed = decoder.collapse_repeats(&indices);
        assert_eq!(collapsed, vec![a_idx, b_idx]);
    }
    #[test]
    fn test_softmax_max() {
        let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let max_prob = softmax_max(&logits);
        assert!(max_prob > 0.0 && max_prob <= 1.0);
        // softmax([1,2,3,2,1]) gives the max class p = 1 / (1 + 2e^-1 + 2e^-2)
        // ≈ 0.498, so the previous `> 0.5` bound was mathematically wrong and
        // this test failed. Pin the actual value instead.
        assert!((max_prob - 0.4984).abs() < 1e-3);
    }
    #[test]
    fn test_empty_logits() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab);
        let empty_logits: Vec<Vec<f32>> = vec![];
        let result = decoder.decode(&empty_logits).unwrap();
        assert_eq!(result, "");
    }
}

View File

@@ -0,0 +1,363 @@
//! OCR Engine Implementation
//!
//! This module provides the main OcrEngine for orchestrating OCR operations.
//! It handles model loading, inference coordination, and result assembly.
use super::{
confidence::aggregate_confidence,
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
inference::{DetectionResult, InferenceEngine, RecognitionResult},
models::{ModelHandle, ModelRegistry},
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, warn};
/// OCR processor trait for custom implementations
///
/// Synchronous counterpart to [`OcrEngine`]'s async API. Implementors must
/// be `Send + Sync` so a processor can be shared across worker threads.
pub trait OcrProcessor: Send + Sync {
/// Process an image and return OCR results
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;
/// Batch process multiple images
///
/// Returns one result per input image, in input order.
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
/// Main OCR Engine with thread-safe model management
///
/// All shared state lives behind `Arc`/`RwLock`, so the engine can be shared
/// across threads and async tasks.
pub struct OcrEngine {
/// Model registry for loading and caching models
registry: Arc<RwLock<ModelRegistry>>,
/// Inference engine for running ONNX models
inference: Arc<InferenceEngine>,
/// Default OCR options
/// (used by `recognize` when the caller supplies no explicit options)
default_options: OcrOptions,
/// Vocabulary for decoding
vocabulary: Arc<Vocabulary>,
/// Whether the engine is warmed up
/// (set once by `warmup`; read-mostly flag)
warmed_up: Arc<RwLock<bool>>,
}
impl OcrEngine {
/// Create a new OCR engine with default models
///
/// # Example
///
/// ```no_run
/// # use ruvector_scipix::ocr::OcrEngine;
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let engine = OcrEngine::new().await?;
/// # Ok(())
/// # }
/// ```
pub async fn new() -> Result<Self> {
Self::with_options(OcrOptions::default()).await
}
/// Create a new OCR engine with custom options
pub async fn with_options(options: OcrOptions) -> Result<Self> {
info!("Initializing OCR engine with options: {:?}", options);
// Initialize model registry
let registry = Arc::new(RwLock::new(ModelRegistry::new()));
// Load default models (in production, these would be downloaded/cached)
debug!("Loading detection model...");
let detection_model = registry.write().load_detection_model().await.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
})?;
debug!("Loading recognition model...");
let recognition_model = registry
.write()
.load_recognition_model()
.await
.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
})?;
let math_model =
if options.enable_math {
debug!("Loading math recognition model...");
Some(registry.write().load_math_model().await.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load math model: {}", e))
})?)
} else {
None
};
// Create inference engine
let inference = Arc::new(InferenceEngine::new(
detection_model,
recognition_model,
math_model,
options.use_gpu,
)?);
// Load vocabulary
let vocabulary = Arc::new(Vocabulary::default());
let engine = Self {
registry,
inference,
default_options: options,
vocabulary,
warmed_up: Arc::new(RwLock::new(false)),
};
info!("OCR engine initialized successfully");
Ok(engine)
}
/// Warm up the engine by running a dummy inference
///
/// This helps reduce latency for the first real inference by initializing
/// all ONNX runtime resources.
pub async fn warmup(&self) -> Result<()> {
if *self.warmed_up.read() {
debug!("Engine already warmed up, skipping");
return Ok(());
}
info!("Warming up OCR engine...");
let start = Instant::now();
// Create a small dummy image (100x100 black image)
let dummy_image = vec![0u8; 100 * 100 * 3];
// Run a dummy inference
let _ = self.recognize(&dummy_image).await;
*self.warmed_up.write() = true;
info!("Engine warmup completed in {:?}", start.elapsed());
Ok(())
}
/// Recognize text in an image using default options
pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
self.recognize_with_options(image_data, &self.default_options)
.await
}
/// Recognize text in an image with custom options
pub async fn recognize_with_options(
&self,
image_data: &[u8],
options: &OcrOptions,
) -> Result<OcrResult> {
let start = Instant::now();
debug!("Starting OCR recognition");
// Step 1: Run text detection
debug!("Running text detection...");
let detection_results = self
.inference
.run_detection(image_data, options.detection_threshold)
.await?;
debug!("Detected {} regions", detection_results.len());
if detection_results.is_empty() {
warn!("No text regions detected");
return Ok(OcrResult {
text: String::new(),
confidence: 0.0,
regions: vec![],
has_math: false,
processing_time_ms: start.elapsed().as_millis() as u64,
});
}
// Step 2: Run recognition on each detected region
debug!("Running text recognition...");
let mut text_regions = Vec::new();
let mut has_math = false;
for detection in detection_results {
// Determine region type
let region_type = if options.enable_math && detection.is_math_likely {
has_math = true;
RegionType::Math
} else {
RegionType::Text
};
// Run appropriate recognition
let recognition = if region_type == RegionType::Math {
self.inference
.run_math_recognition(&detection.region_image, options)
.await?
} else {
self.inference
.run_recognition(&detection.region_image, options)
.await?
};
// Decode the recognition output
let decoded_text = self.decode_output(&recognition, options)?;
// Calculate confidence
let confidence = aggregate_confidence(&recognition.character_confidences);
// Filter by confidence threshold
if confidence < options.recognition_threshold {
debug!(
"Skipping region with low confidence: {:.2} < {:.2}",
confidence, options.recognition_threshold
);
continue;
}
// Build character list
let characters = decoded_text
.chars()
.zip(recognition.character_confidences.iter())
.map(|(ch, &conf)| Character {
char: ch,
confidence: conf,
bbox: None, // Could be populated if available from model
})
.collect();
text_regions.push(TextRegion {
bbox: detection.bbox,
text: decoded_text,
confidence,
region_type,
characters,
});
}
// Step 3: Combine results
let combined_text = text_regions
.iter()
.map(|r| r.text.as_str())
.collect::<Vec<_>>()
.join(" ");
let overall_confidence = if text_regions.is_empty() {
0.0
} else {
text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
};
let processing_time_ms = start.elapsed().as_millis() as u64;
debug!(
"OCR completed in {}ms, recognized {} regions",
processing_time_ms,
text_regions.len()
);
Ok(OcrResult {
text: combined_text,
confidence: overall_confidence,
regions: text_regions,
has_math,
processing_time_ms,
})
}
/// Batch process multiple images
pub async fn recognize_batch(
&self,
images: &[&[u8]],
options: &OcrOptions,
) -> Result<Vec<OcrResult>> {
info!("Processing batch of {} images", images.len());
let start = Instant::now();
// Process images in parallel using rayon
let results: Result<Vec<OcrResult>> = images
.iter()
.map(|image_data| {
// Note: In a real async implementation, we'd use tokio::spawn
// For now, we'll use blocking since we're in a sync context
futures::executor::block_on(self.recognize_with_options(image_data, options))
})
.collect();
info!("Batch processing completed in {:?}", start.elapsed());
results
}
/// Decode recognition output using the selected decoder
fn decode_output(
&self,
recognition: &RecognitionResult,
options: &OcrOptions,
) -> Result<String> {
debug!("Decoding output with {:?} decoder", options.decoder_type);
let decoded = match options.decoder_type {
DecoderType::BeamSearch => {
let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
decoder.decode(&recognition.logits)?
}
DecoderType::Greedy => {
let decoder = GreedyDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
DecoderType::CTC => {
let decoder = CTCDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
};
Ok(decoded)
}
/// Get the current model registry
pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
Arc::clone(&self.registry)
}
/// Get the default options
pub fn default_options(&self) -> &OcrOptions {
&self.default_options
}
/// Check if engine is warmed up
pub fn is_warmed_up(&self) -> bool {
*self.warmed_up.read()
}
}
impl OcrProcessor for OcrEngine {
/// Blocking wrapper over [`OcrEngine::recognize_with_options`].
///
/// NOTE(review): `block_on` must only be called from a purely synchronous
/// context — calling it from inside an async runtime thread can deadlock.
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
// Blocking wrapper for async method
futures::executor::block_on(self.recognize_with_options(image_data, options))
}
/// Blocking wrapper over [`OcrEngine::recognize_batch`]; same caveat as `process`.
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
// Blocking wrapper for async method
futures::executor::block_on(self.recognize_batch(images, options))
}
}
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): these tests only exercise option plumbing and the lock
// pattern; they do not construct an OcrEngine, which would require model
// files on disk.
#[test]
fn test_decoder_selection() {
let options = OcrOptions {
decoder_type: DecoderType::BeamSearch,
..Default::default()
};
assert_eq!(options.decoder_type, DecoderType::BeamSearch);
}
#[test]
fn test_warmup_flag() {
// Sanity-check the Arc<RwLock<bool>> pattern used for `warmed_up`.
let flag = Arc::new(RwLock::new(false));
assert!(!*flag.read());
*flag.write() = true;
assert!(*flag.read());
}
}

View File

@@ -0,0 +1,790 @@
//! ONNX Inference Module
//!
//! This module handles ONNX inference operations for text detection,
//! character recognition, and mathematical expression recognition.
//!
//! # Model Requirements
//!
//! This module requires ONNX models to be available in the configured model directory.
//! Without models, all inference operations will return errors.
//!
//! To use this module:
//! 1. Download compatible ONNX models (PaddleOCR, TrOCR, or similar)
//! 2. Place them in the models directory
//! 3. Enable the `ocr` feature flag
use super::{models::ModelHandle, OcrError, OcrOptions, Result};
use image::{DynamicImage, GenericImageView};
use std::sync::Arc;
use tracing::{debug, info, warn};
#[cfg(feature = "ocr")]
use ndarray::Array4;
#[cfg(feature = "ocr")]
use ort::value::Tensor;
/// Result from text detection
#[derive(Debug, Clone)]
pub struct DetectionResult {
/// Bounding box [x, y, width, height]
/// (presumably in source-image pixel coordinates — TODO confirm against
/// the detection postprocessing)
pub bbox: [f32; 4],
/// Detection confidence
pub confidence: f32,
/// Cropped image region
/// (bytes handed on to the recognition stage; encoding not fixed here)
pub region_image: Vec<u8>,
/// Whether this region likely contains math
/// (routes the region to the math model when math support is enabled)
pub is_math_likely: bool,
}
/// Result from text/math recognition
#[derive(Debug, Clone)]
pub struct RecognitionResult {
/// Logits output from the model [sequence_length, vocab_size]
/// (one inner Vec per decoded frame; consumed by the `Decoder` implementations)
pub logits: Vec<Vec<f32>>,
/// Character-level confidence scores
pub character_confidences: Vec<f32>,
/// Raw output tensor (for debugging)
/// (flattened model output; `None` when not captured)
pub raw_output: Option<Vec<f32>>,
}
/// Inference engine for running ONNX models
///
/// IMPORTANT: This engine requires ONNX models to be loaded.
/// All methods will return errors if models are not properly initialized.
pub struct InferenceEngine {
/// Detection model
detection_model: Arc<ModelHandle>,
/// Recognition model
recognition_model: Arc<ModelHandle>,
/// Math recognition model (optional)
/// (only present when math support was requested at construction)
math_model: Option<Arc<ModelHandle>>,
/// Whether to use GPU acceleration
use_gpu: bool,
/// Whether models are actually loaded (vs placeholder handles)
/// (checked once in `new`; gates every inference entry point)
models_loaded: bool,
}
impl InferenceEngine {
/// Create a new inference engine
///
/// Checks whether both required models (detection and recognition) report a
/// loaded ONNX session. When they do not, construction still succeeds, but
/// every inference call will fail with a descriptive error until the models
/// are configured.
pub fn new(
    detection_model: Arc<ModelHandle>,
    recognition_model: Arc<ModelHandle>,
    math_model: Option<Arc<ModelHandle>>,
    use_gpu: bool,
) -> Result<Self> {
    // Both mandatory models must have a real ONNX session behind them.
    let detection_loaded = detection_model.is_loaded();
    let recognition_loaded = recognition_model.is_loaded();
    let models_loaded = detection_loaded && recognition_loaded;
    if models_loaded {
        let gpu_state = if use_gpu { "enabled" } else { "disabled" };
        info!(
            "Inference engine initialized with loaded models (GPU: {})",
            gpu_state
        );
    } else {
        warn!(
            "ONNX models not fully loaded. Detection: {}, Recognition: {}",
            detection_loaded, recognition_loaded
        );
        warn!("OCR inference will fail until models are properly configured.");
    }
    Ok(Self {
        detection_model,
        recognition_model,
        math_model,
        use_gpu,
        models_loaded,
    })
}
/// Check if the inference engine is ready for use
///
/// Returns `true` only when both the detection and recognition ONNX
/// sessions were successfully loaded at construction time.
pub fn is_ready(&self) -> bool {
    self.models_loaded
}
/// Run text detection on an image
///
/// Decodes `image_data`, resizes it to the detection model's input shape,
/// and returns every region whose confidence is at least `threshold`.
///
/// # Errors
/// - `OcrError::ModelLoading` when the ONNX sessions are not loaded.
/// - `OcrError::Inference` when the crate was built without the `ocr` feature.
pub async fn run_detection(
    &self,
    image_data: &[u8],
    threshold: f32,
) -> Result<Vec<DetectionResult>> {
    if !self.models_loaded {
        return Err(OcrError::ModelLoading(
            "ONNX models not loaded. Please download and configure OCR models before use. \
             See examples/scipix/docs/MODEL_SETUP.md for instructions."
                .to_string(),
        ));
    }
    debug!("Running text detection (threshold: {})", threshold);
    let input_tensor = self.preprocess_image_for_detection(image_data)?;
    // Exactly one of the two cfg blocks below is compiled in.
    #[cfg(feature = "ocr")]
    {
        let detections = self
            .run_onnx_detection(&input_tensor, threshold, image_data)
            .await?;
        debug!("Detected {} regions", detections.len());
        return Ok(detections);
    }
    #[cfg(not(feature = "ocr"))]
    {
        Err(OcrError::Inference(
            "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                .to_string(),
        ))
    }
}
/// Run text recognition on a region image
///
/// `region_image` is an encoded image (typically a PNG crop produced by
/// detection). Returns per-step logits and confidences; decoding into a
/// string is left to the decoder layer.
///
/// # Errors
/// - `OcrError::ModelLoading` when the ONNX sessions are not loaded.
/// - `OcrError::Inference` when the crate was built without the `ocr` feature.
pub async fn run_recognition(
    &self,
    region_image: &[u8],
    options: &OcrOptions,
) -> Result<RecognitionResult> {
    if !self.models_loaded {
        return Err(OcrError::ModelLoading(
            "ONNX models not loaded. Please download and configure OCR models before use."
                .to_string(),
        ));
    }
    debug!("Running text recognition");
    let input_tensor = self.preprocess_image_for_recognition(region_image)?;
    #[cfg(feature = "ocr")]
    {
        let result = self.run_onnx_recognition(&input_tensor, options).await?;
        return Ok(result);
    }
    #[cfg(not(feature = "ocr"))]
    {
        Err(OcrError::Inference(
            "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                .to_string(),
        ))
    }
}
/// Run math recognition on a region image
///
/// Falls back to the general text-recognition model when no math model is
/// configured or its ONNX session failed to load.
///
/// # Errors
/// - `OcrError::ModelLoading` when the core models are not loaded.
/// - `OcrError::Inference` when the crate was built without the `ocr` feature.
pub async fn run_math_recognition(
    &self,
    region_image: &[u8],
    options: &OcrOptions,
) -> Result<RecognitionResult> {
    if !self.models_loaded {
        return Err(OcrError::ModelLoading(
            "ONNX models not loaded. Please download and configure OCR models before use."
                .to_string(),
        ));
    }
    debug!("Running math recognition");
    // Single idiomatic check instead of `is_none() || !...unwrap().is_loaded()`.
    let math_ready = self.math_model.as_ref().map_or(false, |m| m.is_loaded());
    if !math_ready {
        warn!("Math model not loaded, falling back to text recognition");
        return self.run_recognition(region_image, options).await;
    }
    let input_tensor = self.preprocess_image_for_math(region_image)?;
    #[cfg(feature = "ocr")]
    {
        let result = self
            .run_onnx_math_recognition(&input_tensor, options)
            .await?;
        return Ok(result);
    }
    #[cfg(not(feature = "ocr"))]
    {
        Err(OcrError::Inference(
            "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                .to_string(),
        ))
    }
}
/// Preprocess image for detection model
///
/// Decodes the image, resizes it to the model's declared input resolution,
/// and converts it into a planar NCHW `f32` tensor normalized to [0, 1].
///
/// Previously the channel count was hard-coded to 3; it is now taken from
/// the model metadata so a 1-channel detection model gets a correctly-sized
/// grayscale tensor. Behavior is unchanged for the default 3-channel shape.
fn preprocess_image_for_detection(&self, image_data: &[u8]) -> Result<Vec<f32>> {
    let img = image::load_from_memory(image_data)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
    // Input shape is NCHW; index 0 is the (unused) batch dimension.
    let input_shape = self.detection_model.input_shape();
    let (_, channels, height, width) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
    );
    let resized = img.resize_exact(
        width as u32,
        height as u32,
        image::imageops::FilterType::Lanczos3,
    );
    let mut tensor = Vec::with_capacity(channels * height * width);
    if channels == 1 {
        // Grayscale input: single plane scaled to [0, 1].
        let gray = resized.to_luma8();
        for y in 0..height {
            for x in 0..width {
                let pixel = gray.get_pixel(x as u32, y as u32);
                tensor.push(pixel[0] as f32 / 255.0);
            }
        }
    } else {
        // RGB input: channel-planar (NCHW) layout scaled to [0, 1].
        let rgb = resized.to_rgb8();
        for c in 0..3 {
            for y in 0..height {
                for x in 0..width {
                    let pixel = rgb.get_pixel(x as u32, y as u32);
                    tensor.push(pixel[c] as f32 / 255.0);
                }
            }
        }
    }
    Ok(tensor)
}
/// Preprocess image for recognition model
///
/// Resizes the region to the recognition model's input resolution and
/// produces a planar NCHW `f32` tensor normalized to [-1, 1]
/// (`pixel / 127.5 - 1.0`). Grayscale (1-channel) and RGB inputs are
/// supported, selected by the model's declared channel count.
fn preprocess_image_for_recognition(&self, image_data: &[u8]) -> Result<Vec<f32>> {
    let img = image::load_from_memory(image_data)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
    // NCHW metadata; batch dimension (index 0) is ignored.
    let input_shape = self.recognition_model.input_shape();
    let (_, channels, height, width) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
    );
    let resized = img.resize_exact(
        width as u32,
        height as u32,
        image::imageops::FilterType::Lanczos3,
    );
    // Shared normalization: maps [0, 255] onto [-1, 1].
    let normalize = |v: u8| (v as f32 / 127.5) - 1.0;
    let mut tensor = Vec::with_capacity(channels * height * width);
    if channels == 1 {
        // `pixels()` iterates row-major (top-to-bottom, left-to-right),
        // matching a y-then-x nested loop.
        let gray = resized.to_luma8();
        tensor.extend(gray.pixels().map(|p| normalize(p[0])));
    } else {
        // Planar layout: one full row-major pass per channel.
        let rgb = resized.to_rgb8();
        for c in 0..3 {
            tensor.extend(rgb.pixels().map(|p| normalize(p[c])));
        }
    }
    Ok(tensor)
}
/// Preprocess image for math recognition model
///
/// Same pipeline as recognition preprocessing (resize + [-1, 1]
/// normalization into planar NCHW), but sized from the math model's
/// metadata.
///
/// # Errors
/// Returns `OcrError::Inference` when no math model is configured, or
/// `OcrError::ImageProcessing` when the image cannot be decoded.
fn preprocess_image_for_math(&self, image_data: &[u8]) -> Result<Vec<f32>> {
    let math_model = self
        .math_model
        .as_ref()
        .ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
    let img = image::load_from_memory(image_data)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
    // NCHW metadata; batch dimension (index 0) is ignored.
    let input_shape = math_model.input_shape();
    let (_, channels, height, width) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
    );
    let resized = img.resize_exact(
        width as u32,
        height as u32,
        image::imageops::FilterType::Lanczos3,
    );
    let mut tensor = Vec::with_capacity(channels * height * width);
    if channels == 1 {
        // Grayscale: single plane, mapped to [-1, 1].
        let gray = resized.to_luma8();
        for y in 0..height {
            for x in 0..width {
                let pixel = gray.get_pixel(x as u32, y as u32);
                tensor.push((pixel[0] as f32 / 127.5) - 1.0);
            }
        }
    } else {
        // NOTE(review): indexing an RGB pixel with c >= 3 would panic, so
        // this loop assumes channels is 1 or 3 — confirm model metadata
        // never declares more than 3 channels.
        let rgb = resized.to_rgb8();
        for c in 0..channels {
            for y in 0..height {
                for x in 0..width {
                    let pixel = rgb.get_pixel(x as u32, y as u32);
                    tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                }
            }
        }
    }
    Ok(tensor)
}
/// ONNX detection inference (requires `ocr` feature)
///
/// Feeds the preprocessed NCHW tensor through the detection session and
/// parses the first output tensor into `DetectionResult`s, cropping each
/// detected region out of `original_image`.
#[cfg(feature = "ocr")]
async fn run_onnx_detection(
    &self,
    input_tensor: &[f32],
    threshold: f32,
    original_image: &[u8],
) -> Result<Vec<DetectionResult>> {
    let session_arc = self.detection_model.session().ok_or_else(|| {
        OcrError::OnnxRuntime("Detection model session not loaded".to_string())
    })?;
    // ort 2.0's Session::run takes &mut self, hence the mutex around it.
    let mut session = session_arc.lock();
    let input_shape = self.detection_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();
    // Create tensor from input data
    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
    // Convert to dynamic-dimension view and create ORT tensor
    let input_dyn = input_array.into_dyn();
    let input_tensor = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
    // Run inference
    let outputs = session
        .run(ort::inputs![input_tensor])
        .map_err(|e| OcrError::OnnxRuntime(format!("Inference failed: {}", e)))?;
    // Only the first named output is consumed; additional outputs are ignored.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();
    // The original image is re-decoded so regions can be cropped at full
    // resolution rather than at the resized model-input resolution.
    let original_img = image::load_from_memory(original_image)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
    let detections = self.parse_detection_output(&output_data, threshold, &original_img)?;
    Ok(detections)
}
/// Parse detection model output
///
/// Assumes a YOLO-style layout: each detection row is
/// `[cx, cy, w, h, objectness, ...class scores]` with box coordinates
/// normalized to [0, 1] — TODO confirm this matches the exported model.
/// Rows below `threshold` are skipped; surviving boxes are scaled to the
/// original image size, clamped, cropped, and PNG-encoded.
#[cfg(feature = "ocr")]
fn parse_detection_output(
    &self,
    output: &[f32],
    threshold: f32,
    original_img: &DynamicImage,
) -> Result<Vec<DetectionResult>> {
    let mut results = Vec::new();
    let output_shape = self.detection_model.output_shape();
    if output_shape.len() >= 2 {
        let num_detections = output_shape[1];
        // Row stride; 85 is the YOLO default (4 box + 1 objectness + 80 classes).
        let detection_size = if output_shape.len() >= 3 {
            output_shape[2]
        } else {
            85
        };
        for i in 0..num_detections {
            let base_idx = i * detection_size;
            // Need at least [cx, cy, w, h, conf] for this row.
            if base_idx + 5 > output.len() {
                break;
            }
            let confidence = output[base_idx + 4];
            if confidence < threshold {
                continue;
            }
            // Center-format box, normalized coordinates.
            let cx = output[base_idx];
            let cy = output[base_idx + 1];
            let w = output[base_idx + 2];
            let h = output[base_idx + 3];
            let img_width = original_img.width() as f32;
            let img_height = original_img.height() as f32;
            // Convert to top-left pixel coordinates, clamped into the image.
            let x = ((cx - w / 2.0) * img_width).max(0.0);
            let y = ((cy - h / 2.0) * img_height).max(0.0);
            let width = (w * img_width).min(img_width - x);
            let height = (h * img_height).min(img_height - y);
            if width <= 0.0 || height <= 0.0 {
                continue;
            }
            let cropped =
                original_img.crop_imm(x as u32, y as u32, width as u32, height as u32);
            let mut region_bytes = Vec::new();
            cropped
                .write_to(
                    &mut std::io::Cursor::new(&mut region_bytes),
                    image::ImageFormat::Png,
                )
                .map_err(|e| {
                    OcrError::ImageProcessing(format!("Failed to encode region: {}", e))
                })?;
            // Heuristic: very wide or very tall regions are flagged as math
            // candidates (formulas tend to have extreme aspect ratios).
            let aspect_ratio = width / height;
            let is_math_likely = aspect_ratio > 2.0 || aspect_ratio < 0.5;
            results.push(DetectionResult {
                bbox: [x, y, width, height],
                confidence,
                region_image: region_bytes,
                is_math_likely,
            });
        }
    }
    Ok(results)
}
/// ONNX recognition inference (requires `ocr` feature)
///
/// Runs the recognition session on a preprocessed NCHW tensor and slices
/// the flattened `[1, seq_len, vocab_size]` output into per-step logit
/// vectors plus per-step max-softmax confidences.
#[cfg(feature = "ocr")]
async fn run_onnx_recognition(
    &self,
    input_tensor: &[f32],
    _options: &OcrOptions,
) -> Result<RecognitionResult> {
    let session_arc = self.recognition_model.session().ok_or_else(|| {
        OcrError::OnnxRuntime("Recognition model session not loaded".to_string())
    })?;
    // ort 2.0's Session::run takes &mut self, hence the mutex around it.
    let mut session = session_arc.lock();
    let input_shape = self.recognition_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();
    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
    let input_dyn = input_array.into_dyn();
    let input_ort = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
    let outputs = session
        .run(ort::inputs![input_ort])
        .map_err(|e| OcrError::OnnxRuntime(format!("Recognition inference failed: {}", e)))?;
    // Only the first named output is consumed.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();
    // Fall back to the default CRNN-style dimensions if metadata is short.
    let output_shape = self.recognition_model.output_shape();
    let seq_len = output_shape.get(1).copied().unwrap_or(26);
    let vocab_size = output_shape.get(2).copied().unwrap_or(37);
    let mut logits = Vec::with_capacity(seq_len);
    let mut character_confidences = Vec::with_capacity(seq_len);
    for i in 0..seq_len {
        let start_idx = i * vocab_size;
        let end_idx = start_idx + vocab_size;
        if end_idx <= output_data.len() {
            let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
            // The max softmax probability is exp(max - max) / exp_sum
            // = 1 / exp_sum, so there is no need to materialize the full
            // softmax vector as the previous implementation did.
            let max_logit = step_logits
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
            character_confidences.push(1.0 / exp_sum);
            logits.push(step_logits);
        }
    }
    Ok(RecognitionResult {
        logits,
        character_confidences,
        raw_output: Some(output_data),
    })
}
/// ONNX math recognition inference (requires `ocr` feature)
///
/// Same decode pipeline as `run_onnx_recognition`, but driven by the math
/// model's session and output dimensions.
#[cfg(feature = "ocr")]
async fn run_onnx_math_recognition(
    &self,
    input_tensor: &[f32],
    _options: &OcrOptions,
) -> Result<RecognitionResult> {
    let math_model = self
        .math_model
        .as_ref()
        .ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
    let session_arc = math_model
        .session()
        .ok_or_else(|| OcrError::OnnxRuntime("Math model session not loaded".to_string()))?;
    // ort 2.0's Session::run takes &mut self, hence the mutex around it.
    let mut session = session_arc.lock();
    let input_shape = math_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();
    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
    let input_dyn = input_array.into_dyn();
    let input_ort = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
    let outputs = session.run(ort::inputs![input_ort]).map_err(|e| {
        OcrError::OnnxRuntime(format!("Math recognition inference failed: {}", e))
    })?;
    // Only the first named output is consumed.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();
    // Fall back to the default math-model dimensions if metadata is short.
    let output_shape = math_model.output_shape();
    let seq_len = output_shape.get(1).copied().unwrap_or(50);
    let vocab_size = output_shape.get(2).copied().unwrap_or(512);
    let mut logits = Vec::with_capacity(seq_len);
    let mut character_confidences = Vec::with_capacity(seq_len);
    for i in 0..seq_len {
        let start_idx = i * vocab_size;
        let end_idx = start_idx + vocab_size;
        if end_idx <= output_data.len() {
            let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
            // Max softmax probability is 1 / exp_sum (see run_onnx_recognition);
            // avoids allocating a full softmax vector per step.
            let max_logit = step_logits
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
            character_confidences.push(1.0 / exp_sum);
            logits.push(step_logits);
        }
    }
    Ok(RecognitionResult {
        logits,
        character_confidences,
        raw_output: Some(output_data),
    })
}
/// Get detection model handle (whether or not its session is loaded)
pub fn detection_model(&self) -> &ModelHandle {
    &self.detection_model
}
/// Get recognition model handle (whether or not its session is loaded)
pub fn recognition_model(&self) -> &ModelHandle {
    &self.recognition_model
}
/// Get math model if available
pub fn math_model(&self) -> Option<&ModelHandle> {
    self.math_model.as_ref().map(|m| m.as_ref())
}
/// Check if GPU acceleration is enabled
pub fn is_gpu_enabled(&self) -> bool {
    self.use_gpu
}
}
/// Batch inference optimization
impl InferenceEngine {
    /// Run batch detection on multiple images
    ///
    /// Images are processed sequentially; the first failure aborts the
    /// whole batch. Results are returned in input order.
    ///
    /// # Errors
    /// Returns `OcrError::ModelLoading` when models are not loaded, or any
    /// error produced by `run_detection` for an individual image.
    pub async fn run_batch_detection(
        &self,
        images: &[&[u8]],
        threshold: f32,
    ) -> Result<Vec<Vec<DetectionResult>>> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Cannot run batch detection.".to_string(),
            ));
        }
        debug!("Running batch detection on {} images", images.len());
        // Preallocate: exactly one result vector per input image.
        let mut results = Vec::with_capacity(images.len());
        for image in images {
            let detections = self.run_detection(image, threshold).await?;
            results.push(detections);
        }
        Ok(results)
    }
    /// Run batch recognition on multiple regions
    ///
    /// Regions are processed sequentially; the first failure aborts the
    /// whole batch. Results are returned in input order.
    ///
    /// # Errors
    /// Returns `OcrError::ModelLoading` when models are not loaded, or any
    /// error produced by `run_recognition` for an individual region.
    pub async fn run_batch_recognition(
        &self,
        regions: &[&[u8]],
        options: &OcrOptions,
    ) -> Result<Vec<RecognitionResult>> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Cannot run batch recognition.".to_string(),
            ));
        }
        debug!("Running batch recognition on {} regions", regions.len());
        // Preallocate: exactly one result per input region.
        let mut results = Vec::with_capacity(regions.len());
        for region in regions {
            let result = self.run_recognition(region, options).await?;
            results.push(result);
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ocr::models::{ModelMetadata, ModelType};
    use std::path::PathBuf;
    /// Build a handle with placeholder metadata; the path normally does not
    /// exist, so no ONNX session is loaded and `is_loaded()` is false.
    fn create_test_model(model_type: ModelType, path: PathBuf) -> Arc<ModelHandle> {
        let metadata = ModelMetadata {
            name: format!("{:?} Model", model_type),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };
        Arc::new(ModelHandle::new(model_type, path, metadata).unwrap())
    }
    /// Engine construction succeeds even without model files, but it must
    /// report not-ready.
    #[test]
    fn test_inference_engine_creation_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }
    /// Detection on an unready engine must fail with ModelLoading.
    #[tokio::test]
    async fn test_detection_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        let png_data = create_test_png();
        let result = engine.run_detection(&png_data, 0.5).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }
    /// Recognition on an unready engine must fail with ModelLoading.
    #[tokio::test]
    async fn test_recognition_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        let png_data = create_test_png();
        let options = OcrOptions::default();
        let result = engine.run_recognition(&png_data, &options).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }
    /// is_ready mirrors whether both core model sessions loaded.
    #[test]
    fn test_is_ready_reflects_model_state() {
        let detection = create_test_model(ModelType::Detection, PathBuf::from("/fake/path"));
        let recognition = create_test_model(ModelType::Recognition, PathBuf::from("/fake/path"));
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }
    /// Produce a tiny valid PNG (10x10 white) for decode paths.
    fn create_test_png() -> Vec<u8> {
        use image::{ImageBuffer, RgbImage};
        let img: RgbImage = ImageBuffer::from_fn(10, 10, |_, _| image::Rgb([255, 255, 255]));
        let mut bytes: Vec<u8> = Vec::new();
        img.write_to(
            &mut std::io::Cursor::new(&mut bytes),
            image::ImageFormat::Png,
        )
        .unwrap();
        bytes
    }
}

View File

@@ -0,0 +1,235 @@
//! OCR Engine Module
//!
//! This module provides optical character recognition capabilities for the ruvector-scipix system.
//! It supports text detection, character recognition, and mathematical expression recognition using
//! ONNX models for high-performance inference.
//!
//! # Architecture
//!
//! The OCR module is organized into several submodules:
//! - `engine`: Main OcrEngine for orchestrating OCR operations
//! - `models`: Model management, loading, and caching
//! - `inference`: ONNX inference operations for detection and recognition
//! - `decoder`: Output decoding strategies (beam search, greedy, CTC)
//! - `confidence`: Confidence scoring and calibration
//!
//! # Example
//!
//! ```no_run
//! use ruvector_scipix::ocr::{OcrEngine, OcrOptions};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Initialize the OCR engine
//! let engine = OcrEngine::new().await?;
//!
//! // Load an image
//! let image_data = std::fs::read("math_formula.png")?;
//!
//! // Perform OCR
//! let result = engine.recognize(&image_data).await?;
//!
//! println!("Recognized text: {}", result.text);
//! println!("Confidence: {:.2}%", result.confidence * 100.0);
//! # Ok(())
//! # }
//! ```
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
// Submodules
mod confidence;
mod decoder;
mod engine;
mod inference;
mod models;
// Public exports
pub use confidence::{aggregate_confidence, calculate_confidence, ConfidenceCalibrator};
pub use decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary};
pub use engine::{OcrEngine, OcrProcessor};
pub use inference::{DetectionResult, InferenceEngine, RecognitionResult};
pub use models::{ModelHandle, ModelRegistry};
/// OCR processing options
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
    /// Detection threshold for text regions (0.0-1.0); default 0.5
    pub detection_threshold: f32,
    /// Recognition confidence threshold (0.0-1.0); default 0.6
    pub recognition_threshold: f32,
    /// Enable mathematical expression recognition
    pub enable_math: bool,
    /// Decoder type to use
    pub decoder_type: DecoderType,
    /// Beam width for beam search decoder; default 5
    pub beam_width: usize,
    /// Maximum batch size for inference
    pub batch_size: usize,
    /// Enable GPU acceleration if available
    pub use_gpu: bool,
    /// Language hints for recognition (e.g. "en")
    pub languages: Vec<String>,
}
impl Default for OcrOptions {
    /// Balanced defaults: beam-search decoding, math enabled, CPU-only,
    /// English language hint.
    fn default() -> Self {
        Self {
            use_gpu: false,
            batch_size: 1,
            beam_width: 5,
            decoder_type: DecoderType::BeamSearch,
            enable_math: true,
            recognition_threshold: 0.6,
            detection_threshold: 0.5,
            languages: vec!["en".to_string()],
        }
    }
}
/// Decoder type selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DecoderType {
    /// Beam search decoder (higher quality, slower); this is the default
    BeamSearch,
    /// Greedy decoder (faster, lower quality)
    Greedy,
    /// CTC decoder for sequence-to-sequence models
    CTC,
}
/// OCR result containing recognized text and metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Recognized text for the whole image
    pub text: String,
    /// Overall confidence score (0.0-1.0)
    pub confidence: f32,
    /// Detected text regions with their bounding boxes
    pub regions: Vec<TextRegion>,
    /// Whether mathematical expressions were detected
    pub has_math: bool,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}
/// A detected text region with position and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
    /// Bounding box coordinates [x, y, width, height]
    pub bbox: [f32; 4],
    /// Recognized text in this region
    pub text: String,
    /// Confidence score for this region (0.0-1.0)
    pub confidence: f32,
    /// Region type (text, math, etc.)
    pub region_type: RegionType,
    /// Character-level details if available; empty when not produced
    pub characters: Vec<Character>,
}
/// Type of text region
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
    /// Regular text
    Text,
    /// Mathematical expression
    Math,
    /// Diagram or figure
    Diagram,
    /// Table
    Table,
    /// Unknown / unclassified region type
    Unknown,
}
/// Individual character with position and confidence
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Character {
    /// The recognized character
    pub char: char,
    /// Confidence score (0.0-1.0)
    pub confidence: f32,
    /// Bounding box [x, y, width, height] if available
    pub bbox: Option<[f32; 4]>,
}
/// Error types for OCR operations
#[derive(Debug, thiserror::Error)]
pub enum OcrError {
    /// Model file missing, corrupt, or its session not yet configured
    #[error("Model loading error: {0}")]
    ModelLoading(String),
    /// Failure while preparing or running model inference
    #[error("Inference error: {0}")]
    Inference(String),
    /// Image decode/encode/resize failure
    #[error("Image processing error: {0}")]
    ImageProcessing(String),
    /// Failure while decoding model output into text
    #[error("Decoding error: {0}")]
    Decoding(String),
    /// Invalid engine or option configuration
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Underlying filesystem or stream error
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Error reported by the ONNX Runtime backend
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),
}
/// Convenience alias: module-wide `Result` specialized to [`OcrError`].
pub type Result<T> = std::result::Result<T, OcrError>;
#[cfg(test)]
mod tests {
    use super::*;
    /// Default options must match the documented values.
    #[test]
    fn test_ocr_options_default() {
        let options = OcrOptions::default();
        assert_eq!(options.detection_threshold, 0.5);
        assert_eq!(options.recognition_threshold, 0.6);
        assert!(options.enable_math);
        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
        assert_eq!(options.beam_width, 5);
    }
    /// TextRegion can be constructed directly with plain field values.
    #[test]
    fn test_text_region_creation() {
        let region = TextRegion {
            bbox: [10.0, 20.0, 100.0, 30.0],
            text: "Test".to_string(),
            confidence: 0.95,
            region_type: RegionType::Text,
            characters: vec![],
        };
        assert_eq!(region.bbox[0], 10.0);
        assert_eq!(region.text, "Test");
        assert_eq!(region.region_type, RegionType::Text);
    }
    /// DecoderType derives PartialEq/Eq: variants compare by identity.
    #[test]
    fn test_decoder_type_equality() {
        assert_eq!(DecoderType::BeamSearch, DecoderType::BeamSearch);
        assert_ne!(DecoderType::BeamSearch, DecoderType::Greedy);
        assert_ne!(DecoderType::Greedy, DecoderType::CTC);
    }
}

View File

@@ -0,0 +1,373 @@
//! Model Management Module
//!
//! This module handles loading, caching, and managing ONNX models for OCR.
//! It supports lazy loading, model downloading with progress tracking,
//! and checksum verification.
use super::{OcrError, Result};
use dashmap::DashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tracing::{debug, info, warn};
#[cfg(feature = "ocr")]
use ort::session::Session;
#[cfg(feature = "ocr")]
use parking_lot::Mutex;
/// Model types supported by the OCR engine
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Text detection model (finds text regions in images)
    Detection,
    /// Text recognition model (recognizes characters in regions)
    Recognition,
    /// Math expression recognition model
    Math,
}
/// Handle to a loaded ONNX model
///
/// A handle may exist without a live session (e.g. the model file was
/// missing at creation time); callers must check `is_loaded()` before use.
#[derive(Clone)]
pub struct ModelHandle {
    /// Model type
    model_type: ModelType,
    /// Path to the model file
    path: PathBuf,
    /// Model metadata (shapes, version, size, checksum)
    metadata: ModelMetadata,
    /// ONNX Runtime session (when ocr feature is enabled)
    /// Wrapped in Mutex for mutable access required by ort 2.0 Session::run
    #[cfg(feature = "ocr")]
    session: Option<Arc<Mutex<Session>>>,
    /// Mock session for when ocr feature is disabled (always None)
    #[cfg(not(feature = "ocr"))]
    #[allow(dead_code)]
    session: Option<()>,
}
impl ModelHandle {
    /// Create a new model handle
    ///
    /// Eagerly attempts to build an ONNX session when the file exists and
    /// the `ocr` feature is enabled. Session-load failures are downgraded
    /// to warnings: the handle is still returned, but `is_loaded()` will
    /// report `false`.
    pub fn new(model_type: ModelType, path: PathBuf, metadata: ModelMetadata) -> Result<Self> {
        debug!("Creating model handle for {:?} at {:?}", model_type, path);
        #[cfg(feature = "ocr")]
        let session = if path.exists() {
            match Session::builder() {
                Ok(builder) => match builder.commit_from_file(&path) {
                    Ok(session) => {
                        info!("Successfully loaded ONNX model: {:?}", path);
                        Some(Arc::new(Mutex::new(session)))
                    }
                    Err(e) => {
                        // Corrupt/incompatible model file: keep the handle usable
                        // as a placeholder instead of failing construction.
                        warn!("Failed to load ONNX model {:?}: {}", path, e);
                        None
                    }
                },
                Err(e) => {
                    warn!("Failed to create ONNX session builder: {}", e);
                    None
                }
            }
        } else {
            debug!("Model file not found: {:?}", path);
            None
        };
        // Without the ocr feature there is never a real session.
        #[cfg(not(feature = "ocr"))]
        let session: Option<()> = None;
        Ok(Self {
            model_type,
            path,
            metadata,
            session,
        })
    }
    /// Check if the model session is loaded
    pub fn is_loaded(&self) -> bool {
        self.session.is_some()
    }
    /// Get the ONNX session (only available with ocr feature)
    #[cfg(feature = "ocr")]
    pub fn session(&self) -> Option<&Arc<Mutex<Session>>> {
        self.session.as_ref()
    }
    /// Get the model type
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }
    /// Get the model path
    pub fn path(&self) -> &Path {
        &self.path
    }
    /// Get model metadata
    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }
    /// Get input shape for the model (NCHW, per metadata)
    pub fn input_shape(&self) -> &[usize] {
        &self.metadata.input_shape
    }
    /// Get output shape for the model (per metadata)
    pub fn output_shape(&self) -> &[usize] {
        &self.metadata.output_shape
    }
}
/// Model metadata
#[derive(Debug, Clone)]
pub struct ModelMetadata {
/// Model name
pub name: String,
/// Model version
pub version: String,
/// Input tensor shape
pub input_shape: Vec<usize>,
/// Output tensor shape
pub output_shape: Vec<usize>,
/// Expected input data type
pub input_dtype: String,
/// File size in bytes
pub file_size: u64,
/// SHA256 checksum
pub checksum: Option<String>,
}
/// Model registry for loading and caching models
pub struct ModelRegistry {
    /// Cache of loaded models, keyed by model type
    cache: DashMap<ModelType, Arc<ModelHandle>>,
    /// Base directory for model files
    model_dir: PathBuf,
    /// Whether missing model files are tolerated (warn + placeholder handle)
    /// instead of being a hard error
    lazy_loading: bool,
}
impl ModelRegistry {
    /// Create a new model registry rooted at `./models`
    pub fn new() -> Self {
        Self::with_model_dir(PathBuf::from("./models"))
    }
    /// Create a new model registry with custom model directory
    pub fn with_model_dir(model_dir: PathBuf) -> Self {
        info!("Initializing model registry at {:?}", model_dir);
        Self {
            cache: DashMap::new(),
            model_dir,
            lazy_loading: true,
        }
    }
    /// Load the detection model
    pub async fn load_detection_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Detection).await
    }
    /// Load the recognition model
    pub async fn load_recognition_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Recognition).await
    }
    /// Load the math recognition model
    pub async fn load_math_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Math).await
    }
    /// Load a model by type
    ///
    /// Returns the cached handle when available; otherwise creates (and
    /// caches) a new one. With `lazy_loading` enabled, a missing model file
    /// only warns and yields a placeholder handle whose `is_loaded()` is
    /// false; with it disabled, a missing file is a `ModelLoading` error.
    pub async fn load_model(&mut self, model_type: ModelType) -> Result<Arc<ModelHandle>> {
        // Check cache first
        if let Some(handle) = self.cache.get(&model_type) {
            debug!("Model {:?} found in cache", model_type);
            return Ok(Arc::clone(handle.value()));
        }
        info!("Loading model {:?}...", model_type);
        // Get model path
        let model_path = self.get_model_path(model_type);
        // Check if model exists
        if !model_path.exists() {
            if self.lazy_loading {
                warn!(
                    "Model {:?} not found at {:?}. OCR will not work without models.",
                    model_type, model_path
                );
                warn!("Download models from: https://github.com/PaddlePaddle/PaddleOCR or configure custom models.");
            } else {
                return Err(OcrError::ModelLoading(format!(
                    "Model {:?} not found at {:?}",
                    model_type, model_path
                )));
            }
        }
        // Load model metadata
        let metadata = self.get_model_metadata(model_type);
        // Verify checksum if provided
        // NOTE(review): verification is currently a stub — the checksum is
        // logged but never actually compared against the file contents.
        if let Some(ref checksum) = metadata.checksum {
            if model_path.exists() {
                debug!("Verifying model checksum: {}", checksum);
                // In production: verify_checksum(&model_path, checksum)?;
            }
        }
        // Create model handle (will load ONNX session if file exists)
        let handle = Arc::new(ModelHandle::new(model_type, model_path, metadata)?);
        // Cache the handle
        self.cache.insert(model_type, Arc::clone(&handle));
        if handle.is_loaded() {
            info!(
                "Model {:?} loaded successfully with ONNX session",
                model_type
            );
        } else {
            warn!(
                "Model {:?} handle created but ONNX session not loaded",
                model_type
            );
        }
        Ok(handle)
    }
    /// Get the file path for a model type
    fn get_model_path(&self, model_type: ModelType) -> PathBuf {
        let filename = match model_type {
            ModelType::Detection => "text_detection.onnx",
            ModelType::Recognition => "text_recognition.onnx",
            ModelType::Math => "math_recognition.onnx",
        };
        self.model_dir.join(filename)
    }
    /// Get default metadata for a model type
    ///
    /// These are built-in defaults (shapes, approximate sizes), not values
    /// read from the model file itself.
    fn get_model_metadata(&self, model_type: ModelType) -> ModelMetadata {
        match model_type {
            ModelType::Detection => ModelMetadata {
                name: "Text Detection".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 3, 640, 640], // NCHW format
                output_shape: vec![1, 25200, 85],  // Detections
                input_dtype: "float32".to_string(),
                file_size: 50_000_000, // ~50MB
                checksum: None,
            },
            ModelType::Recognition => ModelMetadata {
                name: "Text Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 32, 128], // NCHW format
                output_shape: vec![1, 26, 37],    // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 20_000_000, // ~20MB
                checksum: None,
            },
            ModelType::Math => ModelMetadata {
                name: "Math Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 64, 256], // NCHW format
                output_shape: vec![1, 50, 512],   // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 80_000_000, // ~80MB
                checksum: None,
            },
        }
    }
    /// Clear the model cache
    pub fn clear_cache(&mut self) {
        info!("Clearing model cache");
        self.cache.clear();
    }
    /// Get a cached model if available (does not trigger loading)
    pub fn get_cached(&self, model_type: ModelType) -> Option<Arc<ModelHandle>> {
        self.cache.get(&model_type).map(|h| Arc::clone(h.value()))
    }
    /// Set lazy loading mode (see `load_model` for semantics)
    pub fn set_lazy_loading(&mut self, enabled: bool) {
        self.lazy_loading = enabled;
    }
    /// Get the model directory
    pub fn model_dir(&self) -> &Path {
        &self.model_dir
    }
}
impl Default for ModelRegistry {
    /// Equivalent to `ModelRegistry::new()` (model dir `./models`).
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Default registry points at ./models with lazy loading on.
    #[test]
    fn test_model_registry_creation() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.model_dir(), Path::new("./models"));
        assert!(registry.lazy_loading);
    }
    /// Paths are derived from model type + model dir.
    #[test]
    fn test_model_path_generation() {
        let registry = ModelRegistry::new();
        let path = registry.get_model_path(ModelType::Detection);
        assert!(path.to_string_lossy().contains("text_detection.onnx"));
    }
    /// Built-in default metadata matches the documented recognition shape.
    #[test]
    fn test_model_metadata() {
        let registry = ModelRegistry::new();
        let metadata = registry.get_model_metadata(ModelType::Recognition);
        assert_eq!(metadata.name, "Text Recognition");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.input_shape, vec![1, 1, 32, 128]);
    }
    /// Loading the same model twice must return the same cached Arc.
    #[tokio::test]
    async fn test_model_caching() {
        let mut registry = ModelRegistry::new();
        let model1 = registry.load_detection_model().await.unwrap();
        let model2 = registry.load_detection_model().await.unwrap();
        assert!(Arc::ptr_eq(&model1, &model2));
    }
    /// clear_cache empties the cache.
    #[test]
    fn test_clear_cache() {
        let mut registry = ModelRegistry::new();
        registry.clear_cache();
        assert_eq!(registry.cache.len(), 0);
    }
    /// A handle for a missing file is created but reports not-loaded.
    #[test]
    fn test_model_handle_without_file() {
        let path = PathBuf::from("/nonexistent/model.onnx");
        let metadata = ModelMetadata {
            name: "Test".to_string(),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };
        let handle = ModelHandle::new(ModelType::Detection, path, metadata).unwrap();
        assert!(!handle.is_loaded());
    }
}