Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
384
vendor/ruvector/examples/scipix/src/ocr/confidence.rs
vendored
Normal file
384
vendor/ruvector/examples/scipix/src/ocr/confidence.rs
vendored
Normal file
@@ -0,0 +1,384 @@
|
||||
//! Confidence Scoring Module
|
||||
//!
|
||||
//! This module provides confidence scoring and calibration for OCR results.
|
||||
//! It includes per-character confidence calculation and aggregation methods.
|
||||
|
||||
use super::Result;
|
||||
use std::collections::HashMap;
|
||||
use tracing::debug;
|
||||
|
||||
/// Calculate confidence score for a single character prediction
///
/// # Arguments
/// * `logits` - Raw logits from the model for this character position
///
/// # Returns
/// Confidence score between 0.0 and 1.0: the softmax probability of the
/// most likely class, or 0.0 for an empty slice.
pub fn calculate_confidence(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }

    // Softmax with the max-subtraction trick for numerical stability.
    let max_logit = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();

    // The largest probability belongs to the largest logit.
    let best = logits
        .iter()
        .map(|&x| (x - max_logit).exp() / exp_sum)
        .fold(0.0f32, f32::max);

    best.clamp(0.0, 1.0)
}
|
||||
|
||||
/// Aggregate multiple confidence scores into a single score
///
/// # Arguments
/// * `confidences` - Individual confidence scores
///
/// # Returns
/// Aggregated confidence using the geometric mean (0.0 for empty input).
/// The geometric mean is more conservative than the arithmetic mean: one
/// low score drags the aggregate down.
pub fn aggregate_confidence(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    // nth root of the product of all scores.
    let n = confidences.len() as f32;
    let product = confidences.iter().product::<f32>();
    product.powf(1.0 / n).clamp(0.0, 1.0)
}
|
||||
|
||||
/// Alternative aggregation using arithmetic mean
///
/// Returns 0.0 for an empty slice; otherwise the clamped average of all
/// scores.
pub fn aggregate_confidence_mean(confidences: &[f32]) -> f32 {
    match confidences.len() {
        0 => 0.0,
        n => (confidences.iter().sum::<f32>() / n as f32).clamp(0.0, 1.0),
    }
}
|
||||
|
||||
/// Alternative aggregation using minimum (most conservative)
///
/// Returns 0.0 for an empty slice, consistent with the other aggregation
/// functions. (Previously the fold seed of 1.0 leaked through, reporting
/// full confidence when there was no evidence at all.)
pub fn aggregate_confidence_min(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    confidences
        .iter()
        .fold(1.0f32, |a, &b| a.min(b))
        .clamp(0.0, 1.0)
}
|
||||
|
||||
/// Alternative aggregation using harmonic mean
///
/// Each score is floored at 0.001 before taking its reciprocal so a zero
/// confidence cannot cause a division by zero. Returns 0.0 for empty input.
pub fn aggregate_confidence_harmonic(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    let n = confidences.len() as f32;
    let reciprocal_sum = confidences
        .iter()
        .map(|&c| c.max(0.001).recip())
        .sum::<f32>();

    (n / reciprocal_sum).clamp(0.0, 1.0)
}
|
||||
|
||||
/// Confidence calibrator using isotonic regression
///
/// This calibrator learns a mapping from raw confidence scores to calibrated
/// probabilities using historical data.
pub struct ConfidenceCalibrator {
    /// Calibration mapping: raw_score -> calibrated_score
    /// Keys are percentage bins of the raw score; values are the empirical
    /// accuracy observed for that bin during training.
    calibration_map: HashMap<u8, f32>, // Use u8 for binned scores (0-100)
    /// Whether the calibrator has been trained
    is_trained: bool,
}
|
||||
|
||||
impl ConfidenceCalibrator {
    /// Create a new, untrained calibrator
    ///
    /// Until [`ConfidenceCalibrator::train`] succeeds, `calibrate` passes
    /// raw scores through unchanged (clamped to [0, 1]).
    pub fn new() -> Self {
        Self {
            calibration_map: HashMap::new(),
            is_trained: false,
        }
    }

    /// Train the calibrator on labeled data
    ///
    /// # Arguments
    /// * `predictions` - Raw confidence scores from the model
    /// * `ground_truth` - Binary labels (1.0 if correct, 0.0 if incorrect)
    ///
    /// # Errors
    /// Returns `OcrError::InvalidConfig` when the two slices differ in
    /// length or are empty.
    pub fn train(&mut self, predictions: &[f32], ground_truth: &[f32]) -> Result<()> {
        debug!(
            "Training confidence calibrator on {} samples",
            predictions.len()
        );

        if predictions.len() != ground_truth.len() {
            return Err(super::OcrError::InvalidConfig(
                "Predictions and ground truth must have same length".to_string(),
            ));
        }

        if predictions.is_empty() {
            return Err(super::OcrError::InvalidConfig(
                "Cannot train on empty data".to_string(),
            ));
        }

        // Bin the scores (0.0-1.0 -> 0-100); each bin collects the labels of
        // the samples whose raw score fell into it.
        let mut bins: HashMap<u8, Vec<f32>> = HashMap::new();

        for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
            let bin = (pred * 100.0).clamp(0.0, 100.0) as u8;
            bins.entry(bin).or_insert_with(Vec::new).push(truth);
        }

        // Calculate mean accuracy for each bin — the empirical probability
        // that a prediction falling in that bin is correct. Any previous
        // training result is discarded first.
        self.calibration_map.clear();
        for (bin, truths) in bins {
            let mean_accuracy = truths.iter().sum::<f32>() / truths.len() as f32;
            self.calibration_map.insert(bin, mean_accuracy);
        }

        // Perform isotonic regression (simplified version)
        self.enforce_monotonicity();

        self.is_trained = true;
        debug!(
            "Calibrator trained with {} bins",
            self.calibration_map.len()
        );

        Ok(())
    }

    /// Enforce monotonicity constraint (isotonic regression)
    ///
    /// Runs a single running-maximum pass over the bins in ascending order,
    /// so every calibrated value is at least as large as the one before it.
    /// NOTE(review): this only raises values (it never averages them down),
    /// which is a simplification of true isotonic regression (PAVA).
    fn enforce_monotonicity(&mut self) {
        let mut sorted_bins: Vec<_> = self.calibration_map.iter().collect();
        sorted_bins.sort_by_key(|(bin, _)| *bin);

        // Simple isotonic regression: ensure calibrated scores are non-decreasing
        let mut adjusted = HashMap::new();
        let mut prev_value = 0.0;

        for (&bin, &value) in sorted_bins {
            let adjusted_value = value.max(prev_value);
            adjusted.insert(bin, adjusted_value);
            prev_value = adjusted_value;
        }

        self.calibration_map = adjusted;
    }

    /// Calibrate a raw confidence score
    ///
    /// Bins `raw_score` the same way training did (percent bins 0-100) and
    /// returns the learned accuracy for that bin, interpolating between the
    /// nearest trained bins when the exact bin is missing.
    pub fn calibrate(&self, raw_score: f32) -> f32 {
        if !self.is_trained {
            // If not trained, return raw score
            return raw_score.clamp(0.0, 1.0);
        }

        let bin = (raw_score * 100.0).clamp(0.0, 100.0) as u8;

        // Look up calibrated score, or interpolate
        if let Some(&calibrated) = self.calibration_map.get(&bin) {
            return calibrated;
        }

        // Interpolate between nearest bins
        self.interpolate(bin)
    }

    /// Interpolate calibrated score for a bin without direct mapping
    ///
    /// `lower` becomes the largest trained bin below the target, `upper`
    /// the smallest trained bin above it; the result is a linear blend of
    /// their values. With only one neighbor, that neighbor's value is used;
    /// with none (untrained map), fall back to the bin as a fraction.
    fn interpolate(&self, target_bin: u8) -> f32 {
        let mut lower = None;
        let mut upper = None;

        for &bin in self.calibration_map.keys() {
            if bin < target_bin {
                lower = Some(lower.map_or(bin, |l: u8| l.max(bin)));
            } else if bin > target_bin {
                upper = Some(upper.map_or(bin, |u: u8| u.min(bin)));
            }
        }

        match (lower, upper) {
            (Some(l), Some(u)) => {
                let l_val = self.calibration_map[&l];
                let u_val = self.calibration_map[&u];
                // l < target_bin < u here, so u - l is never zero.
                let alpha = (target_bin - l) as f32 / (u - l) as f32;
                l_val + alpha * (u_val - l_val)
            }
            (Some(l), None) => self.calibration_map[&l],
            (None, Some(u)) => self.calibration_map[&u],
            (None, None) => target_bin as f32 / 100.0, // Fallback
        }
    }

    /// Check if the calibrator is trained
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

    /// Reset the calibrator
    ///
    /// Clears the learned mapping and marks the calibrator untrained, so
    /// `calibrate` passes raw scores through again.
    pub fn reset(&mut self) {
        self.calibration_map.clear();
        self.is_trained = false;
    }
}
|
||||
|
||||
// The default calibrator is untrained, identical to `ConfidenceCalibrator::new()`.
impl Default for ConfidenceCalibrator {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Calculate Expected Calibration Error (ECE)
///
/// Measures the difference between predicted confidence and actual accuracy:
/// predictions are bucketed into `n_bins` equal-width confidence bins and
/// ECE is the sample-weighted average of |mean confidence - mean accuracy|
/// per non-empty bin. 0.0 means perfectly calibrated.
///
/// Returns 0.0 when the slices differ in length, are empty, or `n_bins` is
/// 0 (the latter previously caused an index-underflow panic on `n_bins - 1`).
pub fn calculate_ece(predictions: &[f32], ground_truth: &[f32], n_bins: usize) -> f32 {
    if predictions.len() != ground_truth.len() || predictions.is_empty() || n_bins == 0 {
        return 0.0;
    }

    let mut bins: Vec<Vec<(f32, f32)>> = vec![Vec::new(); n_bins];

    // Assign predictions to bins; the `min` clamps pred == 1.0 into the
    // last bin instead of indexing out of range.
    for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
        let bin_idx = ((pred * n_bins as f32) as usize).min(n_bins - 1);
        bins[bin_idx].push((pred, truth));
    }

    let total = predictions.len() as f32;

    bins.iter()
        .filter(|bin| !bin.is_empty())
        .map(|bin| {
            let bin_size = bin.len() as f32;
            let avg_confidence: f32 = bin.iter().map(|(p, _)| p).sum::<f32>() / bin_size;
            let avg_accuracy: f32 = bin.iter().map(|(_, t)| t).sum::<f32>() / bin_size;
            (bin_size / total) * (avg_confidence - avg_accuracy).abs()
        })
        .sum()
}
|
||||
|
||||
// Unit tests for the confidence scoring and calibration utilities above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_confidence() {
        let logits = vec![1.0, 5.0, 2.0, 1.0];
        let conf = calculate_confidence(&logits);
        assert!(conf > 0.5);
        assert!(conf <= 1.0);
    }

    #[test]
    fn test_calculate_confidence_empty() {
        let logits: Vec<f32> = vec![];
        let conf = calculate_confidence(&logits);
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_aggregate_confidence() {
        let confidences = vec![0.9, 0.8, 0.95, 0.85];
        let agg = aggregate_confidence(&confidences);
        assert!(agg > 0.0 && agg <= 1.0);
        assert!(agg < 0.9); // Geometric mean should be less than max
    }

    #[test]
    fn test_aggregate_confidence_mean() {
        let confidences = vec![0.8, 0.9, 0.7];
        let mean = aggregate_confidence_mean(&confidences);
        assert_eq!(mean, 0.8); // (0.8 + 0.9 + 0.7) / 3
    }

    #[test]
    fn test_aggregate_confidence_min() {
        let confidences = vec![0.9, 0.7, 0.95];
        let min = aggregate_confidence_min(&confidences);
        assert_eq!(min, 0.7);
    }

    #[test]
    fn test_aggregate_confidence_harmonic() {
        // 0.5 is exactly representable, so the harmonic mean is exact here.
        let confidences = vec![0.5, 0.5];
        let harmonic = aggregate_confidence_harmonic(&confidences);
        assert_eq!(harmonic, 0.5);
    }

    #[test]
    fn test_calibrator_training() {
        let mut calibrator = ConfidenceCalibrator::new();
        assert!(!calibrator.is_trained());

        let predictions = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0, 0.0];

        let result = calibrator.train(&predictions, &ground_truth);
        assert!(result.is_ok());
        assert!(calibrator.is_trained());
    }

    #[test]
    fn test_calibrator_calibrate() {
        let mut calibrator = ConfidenceCalibrator::new();

        // Before training, should return raw score
        assert_eq!(calibrator.calibrate(0.8), 0.8);

        // Train with some data
        let predictions = vec![0.9, 0.9, 0.8, 0.8, 0.7, 0.7];
        let ground_truth = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();

        // After training, should return calibrated score
        let calibrated = calibrator.calibrate(0.85);
        assert!(calibrated >= 0.0 && calibrated <= 1.0);
    }

    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = ConfidenceCalibrator::new();
        let predictions = vec![0.9, 0.8];
        let ground_truth = vec![1.0, 0.0];

        calibrator.train(&predictions, &ground_truth).unwrap();
        assert!(calibrator.is_trained());

        calibrator.reset();
        assert!(!calibrator.is_trained());
    }

    #[test]
    fn test_calculate_ece() {
        let predictions = vec![0.9, 0.7, 0.6, 0.8];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0];
        let ece = calculate_ece(&predictions, &ground_truth, 3);
        assert!(ece >= 0.0 && ece <= 1.0);
    }

    #[test]
    fn test_calibrator_monotonicity() {
        let mut calibrator = ConfidenceCalibrator::new();

        // Create data that would violate monotonicity
        let predictions = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let ground_truth = vec![0.2, 0.3, 0.2, 0.5, 0.4, 0.7, 0.8, 0.9, 1.0];

        calibrator.train(&predictions, &ground_truth).unwrap();

        // Check monotonicity
        let score1 = calibrator.calibrate(0.3);
        let score2 = calibrator.calibrate(0.5);
        let score3 = calibrator.calibrate(0.7);

        assert!(score2 >= score1, "Calibrated scores should be monotonic");
        assert!(score3 >= score2, "Calibrated scores should be monotonic");
    }
}
|
||||
441
vendor/ruvector/examples/scipix/src/ocr/decoder.rs
vendored
Normal file
441
vendor/ruvector/examples/scipix/src/ocr/decoder.rs
vendored
Normal file
@@ -0,0 +1,441 @@
|
||||
//! Output Decoding Module
|
||||
//!
|
||||
//! This module provides various decoding strategies for converting
|
||||
//! model output logits into text strings.
|
||||
|
||||
use super::{OcrError, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
/// Decoder trait for converting logits to text
///
/// Each inner `Vec<f32>` in `logits` is one time-step (frame) of per-class
/// scores from the recognition model.
pub trait Decoder: Send + Sync {
    /// Decode logits to text
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String>;

    /// Decode with confidence scores per character
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        // Default implementation just returns uniform confidence
        // NOTE(review): `text.len()` counts UTF-8 bytes, not characters;
        // that matches characters only for ASCII vocabularies — confirm if
        // multi-byte characters are ever added.
        let text = self.decode(logits)?;
        let confidences = vec![1.0; text.len()];
        Ok((text, confidences))
    }
}
|
||||
|
||||
/// Vocabulary mapping for character recognition
#[derive(Debug, Clone)]
pub struct Vocabulary {
    /// Index to character mapping
    idx_to_char: HashMap<usize, char>,
    /// Character to index mapping
    char_to_idx: HashMap<char, usize>,
    /// Blank token index for CTC
    blank_idx: usize,
}

impl Vocabulary {
    /// Create a new vocabulary
    ///
    /// `chars` is the ordered character list: index `i` maps to `chars[i]`.
    /// `blank_idx` is the CTC blank token index.
    pub fn new(chars: Vec<char>, blank_idx: usize) -> Self {
        let mut idx_to_char = HashMap::with_capacity(chars.len());
        let mut char_to_idx = HashMap::with_capacity(chars.len());

        for (index, ch) in chars.into_iter().enumerate() {
            idx_to_char.insert(index, ch);
            char_to_idx.insert(ch, index);
        }

        Self {
            idx_to_char,
            char_to_idx,
            blank_idx,
        }
    }

    /// Get character by index
    pub fn get_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Get index by character
    pub fn get_idx(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Get blank token index
    pub fn blank_idx(&self) -> usize {
        self.blank_idx
    }

    /// Get vocabulary size (number of characters in the index map)
    pub fn size(&self) -> usize {
        self.idx_to_char.len()
    }
}
|
||||
|
||||
impl Default for Vocabulary {
|
||||
fn default() -> Self {
|
||||
// Default vocabulary: lowercase letters + digits + space + blank
|
||||
let mut chars = Vec::new();
|
||||
|
||||
// Add lowercase letters
|
||||
for c in 'a'..='z' {
|
||||
chars.push(c);
|
||||
}
|
||||
|
||||
// Add digits
|
||||
for c in '0'..='9' {
|
||||
chars.push(c);
|
||||
}
|
||||
|
||||
// Add space
|
||||
chars.push(' ');
|
||||
|
||||
// Blank token is at the end
|
||||
let blank_idx = chars.len();
|
||||
|
||||
Self::new(chars, blank_idx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy decoder - selects the character with highest probability at each step
|
||||
pub struct GreedyDecoder {
|
||||
vocabulary: Arc<Vocabulary>,
|
||||
}
|
||||
|
||||
impl GreedyDecoder {
|
||||
/// Create a new greedy decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
|
||||
Self { vocabulary }
|
||||
}
|
||||
|
||||
/// Find the index with maximum value in a slice
|
||||
fn argmax(values: &[f32]) -> usize {
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for GreedyDecoder {
|
||||
fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
|
||||
debug!("Greedy decoding {} frames", logits.len());
|
||||
|
||||
let mut result = String::new();
|
||||
let mut prev_idx = None;
|
||||
|
||||
for frame_logits in logits {
|
||||
let idx = Self::argmax(frame_logits);
|
||||
|
||||
// Skip blank tokens and repeated characters
|
||||
if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
|
||||
if let Some(ch) = self.vocabulary.get_char(idx) {
|
||||
result.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
prev_idx = Some(idx);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
|
||||
let mut result = String::new();
|
||||
let mut confidences = Vec::new();
|
||||
let mut prev_idx = None;
|
||||
|
||||
for frame_logits in logits {
|
||||
let idx = Self::argmax(frame_logits);
|
||||
let confidence = softmax_max(frame_logits);
|
||||
|
||||
// Skip blank tokens and repeated characters
|
||||
if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
|
||||
if let Some(ch) = self.vocabulary.get_char(idx) {
|
||||
result.push(ch);
|
||||
confidences.push(confidence);
|
||||
}
|
||||
}
|
||||
|
||||
prev_idx = Some(idx);
|
||||
}
|
||||
|
||||
Ok((result, confidences))
|
||||
}
|
||||
}
|
||||
|
||||
/// Beam search decoder - maintains top-k hypotheses for better accuracy
|
||||
pub struct BeamSearchDecoder {
|
||||
vocabulary: Arc<Vocabulary>,
|
||||
beam_width: usize,
|
||||
}
|
||||
|
||||
impl BeamSearchDecoder {
|
||||
/// Create a new beam search decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>, beam_width: usize) -> Self {
|
||||
Self {
|
||||
vocabulary,
|
||||
beam_width: beam_width.max(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get beam width
|
||||
pub fn beam_width(&self) -> usize {
|
||||
self.beam_width
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BeamSearchDecoder {
    /// Beam-search decode over per-frame logits.
    ///
    /// Each beam is `(text, score, last_idx)`. Scores are accumulated sums
    /// of raw logits (not log-softmax probabilities), so they are only
    /// comparable within the same frame set.
    /// NOTE(review): beams with identical text are never merged, and the
    /// `partial_cmp(..).unwrap()` sorts will panic if a logit is NaN —
    /// confirm the model cannot emit NaN.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!(
            "Beam search decoding {} frames (beam_width: {})",
            logits.len(),
            self.beam_width
        );

        if logits.is_empty() {
            return Ok(String::new());
        }

        // Initialize beams: (text, score, last_idx)
        let mut beams: Vec<(String, f32, Option<usize>)> = vec![(String::new(), 0.0, None)];

        for frame_logits in logits {
            let mut new_beams = Vec::new();

            for (text, score, last_idx) in &beams {
                // Get top-k predictions for this frame, sorted by logit
                // descending.
                let mut indexed_logits: Vec<(usize, f32)> = frame_logits
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| (i, v))
                    .collect();
                indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                // Expand each beam with top-k predictions
                for (idx, logit) in indexed_logits.iter().take(self.beam_width) {
                    let new_score = score + logit;

                    // Skip blank tokens: keep the beam's text but update the
                    // score and remember the blank as the last index.
                    if *idx == self.vocabulary.blank_idx() {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }

                    // Skip repeated characters (CTC collapse): same class as
                    // the previous frame extends the beam without new text.
                    if Some(*idx) == *last_idx {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }

                    // Add character to beam (indices outside the vocabulary
                    // are silently dropped and produce no new beam).
                    if let Some(ch) = self.vocabulary.get_char(*idx) {
                        let mut new_text = text.clone();
                        new_text.push(ch);
                        new_beams.push((new_text, new_score, Some(*idx)));
                    }
                }
            }

            // Keep top beam_width beams (prune to the highest scores).
            new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            new_beams.truncate(self.beam_width);
            beams = new_beams;
        }

        // Return the best beam
        Ok(beams
            .first()
            .map(|(text, _, _)| text.clone())
            .unwrap_or_default())
    }
}
|
||||
|
||||
/// CTC (Connectionist Temporal Classification) decoder
|
||||
pub struct CTCDecoder {
|
||||
vocabulary: Arc<Vocabulary>,
|
||||
}
|
||||
|
||||
impl CTCDecoder {
|
||||
/// Create a new CTC decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
|
||||
Self { vocabulary }
|
||||
}
|
||||
|
||||
/// Collapse repeated characters and remove blanks
|
||||
fn collapse_repeats(&self, indices: &[usize]) -> Vec<usize> {
|
||||
let mut result = Vec::new();
|
||||
let mut prev_idx = None;
|
||||
|
||||
for &idx in indices {
|
||||
// Skip blanks
|
||||
if idx == self.vocabulary.blank_idx() {
|
||||
prev_idx = Some(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip repeats
|
||||
if Some(idx) != prev_idx {
|
||||
result.push(idx);
|
||||
}
|
||||
|
||||
prev_idx = Some(idx);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for CTCDecoder {
    /// Best-path CTC decode: greedy argmax per frame, then collapse.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("CTC decoding {} frames", logits.len());

        // Get best path (greedy): argmax class index per frame.
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();

        // Collapse repeats and remove blanks
        let collapsed = self.collapse_repeats(&indices);

        // Convert to text; indices without a vocabulary entry are dropped.
        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();

        Ok(text)
    }

    /// Decode and attach, for each surviving (collapsed) index, the max
    /// softmax probability of the frame where it first appeared.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();
        // One confidence per frame, aligned with `indices`.
        let confidences: Vec<f32> = logits.iter().map(|frame| softmax_max(frame)).collect();

        let collapsed = self.collapse_repeats(&indices);

        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();

        // Map confidences to non-collapsed positions by re-running the same
        // collapse conditions as `collapse_repeats`. `confidence_idx`
        // advances once per frame, so it always mirrors the loop position
        // and the bounds check below is always satisfied.
        // NOTE(review): a confidence is pushed for every surviving index,
        // but `text` above drops indices that have no vocabulary entry via
        // `filter_map`; for such an index `text` and `result_confidences`
        // would misalign — confirm all model indices map into the vocabulary.
        let mut result_confidences = Vec::new();
        let mut prev_idx = None;
        let mut confidence_idx = 0;

        for &idx in &indices {
            if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
                if confidence_idx < confidences.len() {
                    result_confidences.push(confidences[confidence_idx]);
                }
            }
            confidence_idx += 1;
            prev_idx = Some(idx);
        }

        Ok((text, result_confidences))
    }
}
|
||||
|
||||
/// Calculate softmax and return max probability
///
/// Uses the max-subtraction trick for numerical stability. Since the
/// largest shifted logit is exactly 0 and `exp(0) == 1`, the top softmax
/// probability is simply `1 / exp_sum`. (The previous version recomputed
/// the maximum a second time and exponentiated `max - max`, which is
/// always 1.0 — same result, wasted work.)
///
/// Returns 0.0 for an empty slice.
fn softmax_max(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }

    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();

    1.0 / exp_sum
}
|
||||
|
||||
// Unit tests for the decoding strategies above.
#[cfg(test)]
mod tests {
    use super::*;

    fn create_test_vocabulary() -> Arc<Vocabulary> {
        Arc::new(Vocabulary::default())
    }

    #[test]
    fn test_vocabulary_default() {
        let vocab = Vocabulary::default();
        assert!(vocab.size() > 0);
        assert_eq!(vocab.get_char(0), Some('a'));
        assert_eq!(vocab.get_idx('a'), Some(0));
    }

    #[test]
    fn test_greedy_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab.clone());

        // Mock logits for "hi"
        let h_idx = vocab.get_idx('h').unwrap();
        let i_idx = vocab.get_idx('i').unwrap();
        let blank = vocab.blank_idx();

        let mut logits = vec![
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
        ];

        logits[0][h_idx] = 10.0;
        logits[1][blank] = 10.0;
        logits[2][i_idx] = 10.0;

        let result = decoder.decode(&logits).unwrap();
        assert_eq!(result, "hi");
    }

    #[test]
    fn test_beam_search_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = BeamSearchDecoder::new(vocab.clone(), 3);

        assert_eq!(decoder.beam_width(), 3);

        let logits = vec![vec![0.0; vocab.size() + 1]; 5];
        let result = decoder.decode(&logits);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ctc_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = CTCDecoder::new(vocab.clone());

        // Test collapse repeats
        let a_idx = vocab.get_idx('a').unwrap();
        let b_idx = vocab.get_idx('b').unwrap();
        let blank = vocab.blank_idx();

        let indices = vec![a_idx, a_idx, blank, b_idx, b_idx, b_idx];
        let collapsed = decoder.collapse_repeats(&indices);

        assert_eq!(collapsed, vec![a_idx, b_idx]);
    }

    #[test]
    fn test_softmax_max() {
        // Fixed: the previous fixture [1, 2, 3, 2, 1] has a max softmax
        // probability of 1 / (2e^-2 + 2e^-1 + 1) ~= 0.498, which fails the
        // `> 0.5` assertion below. Use a sharper peak so the intended
        // property (the max class dominates) actually holds.
        let logits = vec![1.0, 3.0, 5.0, 3.0, 1.0];
        let max_prob = softmax_max(&logits);
        assert!(max_prob > 0.0 && max_prob <= 1.0);
        assert!(max_prob > 0.5); // The max should have high probability
    }

    #[test]
    fn test_empty_logits() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab);

        let empty_logits: Vec<Vec<f32>> = vec![];
        let result = decoder.decode(&empty_logits).unwrap();
        assert_eq!(result, "");
    }
}
|
||||
363
vendor/ruvector/examples/scipix/src/ocr/engine.rs
vendored
Normal file
363
vendor/ruvector/examples/scipix/src/ocr/engine.rs
vendored
Normal file
@@ -0,0 +1,363 @@
|
||||
//! OCR Engine Implementation
|
||||
//!
|
||||
//! This module provides the main OcrEngine for orchestrating OCR operations.
|
||||
//! It handles model loading, inference coordination, and result assembly.
|
||||
|
||||
use super::{
|
||||
confidence::aggregate_confidence,
|
||||
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
|
||||
inference::{DetectionResult, InferenceEngine, RecognitionResult},
|
||||
models::{ModelHandle, ModelRegistry},
|
||||
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// OCR processor trait for custom implementations
pub trait OcrProcessor: Send + Sync {
    /// Process an image and return OCR results
    ///
    /// `image_data` is a raw image byte buffer; the accepted encoding is
    /// not constrained here — presumably implementation-defined, confirm
    /// against concrete implementors.
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;

    /// Batch process multiple images
    ///
    /// NOTE(review): expected to yield one `OcrResult` per input image —
    /// verify ordering guarantees with each implementation.
    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
|
||||
|
||||
/// Main OCR Engine with thread-safe model management
///
/// All shared state is behind `Arc` (with `parking_lot::RwLock` where
/// mutation is needed), so the engine can be shared across tasks.
pub struct OcrEngine {
    /// Model registry for loading and caching models
    registry: Arc<RwLock<ModelRegistry>>,

    /// Inference engine for running ONNX models
    inference: Arc<InferenceEngine>,

    /// Default OCR options, used by `recognize` when the caller does not
    /// supply options explicitly
    default_options: OcrOptions,

    /// Vocabulary for decoding
    vocabulary: Arc<Vocabulary>,

    /// Whether the engine is warmed up (set once by `warmup`)
    warmed_up: Arc<RwLock<bool>>,
}
|
||||
|
||||
impl OcrEngine {
|
||||
/// Create a new OCR engine with default models
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```no_run
|
||||
/// # use ruvector_scipix::ocr::OcrEngine;
|
||||
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
||||
/// let engine = OcrEngine::new().await?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
pub async fn new() -> Result<Self> {
|
||||
Self::with_options(OcrOptions::default()).await
|
||||
}
|
||||
|
||||
/// Create a new OCR engine with custom options
|
||||
pub async fn with_options(options: OcrOptions) -> Result<Self> {
|
||||
info!("Initializing OCR engine with options: {:?}", options);
|
||||
|
||||
// Initialize model registry
|
||||
let registry = Arc::new(RwLock::new(ModelRegistry::new()));
|
||||
|
||||
// Load default models (in production, these would be downloaded/cached)
|
||||
debug!("Loading detection model...");
|
||||
let detection_model = registry.write().load_detection_model().await.map_err(|e| {
|
||||
OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
|
||||
})?;
|
||||
|
||||
debug!("Loading recognition model...");
|
||||
let recognition_model = registry
|
||||
.write()
|
||||
.load_recognition_model()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
|
||||
})?;
|
||||
|
||||
let math_model =
|
||||
if options.enable_math {
|
||||
debug!("Loading math recognition model...");
|
||||
Some(registry.write().load_math_model().await.map_err(|e| {
|
||||
OcrError::ModelLoading(format!("Failed to load math model: {}", e))
|
||||
})?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Create inference engine
|
||||
let inference = Arc::new(InferenceEngine::new(
|
||||
detection_model,
|
||||
recognition_model,
|
||||
math_model,
|
||||
options.use_gpu,
|
||||
)?);
|
||||
|
||||
// Load vocabulary
|
||||
let vocabulary = Arc::new(Vocabulary::default());
|
||||
|
||||
let engine = Self {
|
||||
registry,
|
||||
inference,
|
||||
default_options: options,
|
||||
vocabulary,
|
||||
warmed_up: Arc::new(RwLock::new(false)),
|
||||
};
|
||||
|
||||
info!("OCR engine initialized successfully");
|
||||
Ok(engine)
|
||||
}
|
||||
|
||||
/// Warm up the engine by running a dummy inference
|
||||
///
|
||||
/// This helps reduce latency for the first real inference by initializing
|
||||
/// all ONNX runtime resources.
|
||||
pub async fn warmup(&self) -> Result<()> {
|
||||
if *self.warmed_up.read() {
|
||||
debug!("Engine already warmed up, skipping");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
info!("Warming up OCR engine...");
|
||||
let start = Instant::now();
|
||||
|
||||
// Create a small dummy image (100x100 black image)
|
||||
let dummy_image = vec![0u8; 100 * 100 * 3];
|
||||
|
||||
// Run a dummy inference
|
||||
let _ = self.recognize(&dummy_image).await;
|
||||
|
||||
*self.warmed_up.write() = true;
|
||||
info!("Engine warmup completed in {:?}", start.elapsed());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Recognize text in an image using default options
|
||||
pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
|
||||
self.recognize_with_options(image_data, &self.default_options)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Recognize text in an image with custom options
|
||||
pub async fn recognize_with_options(
|
||||
&self,
|
||||
image_data: &[u8],
|
||||
options: &OcrOptions,
|
||||
) -> Result<OcrResult> {
|
||||
let start = Instant::now();
|
||||
debug!("Starting OCR recognition");
|
||||
|
||||
// Step 1: Run text detection
|
||||
debug!("Running text detection...");
|
||||
let detection_results = self
|
||||
.inference
|
||||
.run_detection(image_data, options.detection_threshold)
|
||||
.await?;
|
||||
|
||||
debug!("Detected {} regions", detection_results.len());
|
||||
|
||||
if detection_results.is_empty() {
|
||||
warn!("No text regions detected");
|
||||
return Ok(OcrResult {
|
||||
text: String::new(),
|
||||
confidence: 0.0,
|
||||
regions: vec![],
|
||||
has_math: false,
|
||||
processing_time_ms: start.elapsed().as_millis() as u64,
|
||||
});
|
||||
}
|
||||
|
||||
// Step 2: Run recognition on each detected region
|
||||
debug!("Running text recognition...");
|
||||
let mut text_regions = Vec::new();
|
||||
let mut has_math = false;
|
||||
|
||||
for detection in detection_results {
|
||||
// Determine region type
|
||||
let region_type = if options.enable_math && detection.is_math_likely {
|
||||
has_math = true;
|
||||
RegionType::Math
|
||||
} else {
|
||||
RegionType::Text
|
||||
};
|
||||
|
||||
// Run appropriate recognition
|
||||
let recognition = if region_type == RegionType::Math {
|
||||
self.inference
|
||||
.run_math_recognition(&detection.region_image, options)
|
||||
.await?
|
||||
} else {
|
||||
self.inference
|
||||
.run_recognition(&detection.region_image, options)
|
||||
.await?
|
||||
};
|
||||
|
||||
// Decode the recognition output
|
||||
let decoded_text = self.decode_output(&recognition, options)?;
|
||||
|
||||
// Calculate confidence
|
||||
let confidence = aggregate_confidence(&recognition.character_confidences);
|
||||
|
||||
// Filter by confidence threshold
|
||||
if confidence < options.recognition_threshold {
|
||||
debug!(
|
||||
"Skipping region with low confidence: {:.2} < {:.2}",
|
||||
confidence, options.recognition_threshold
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Build character list
|
||||
let characters = decoded_text
|
||||
.chars()
|
||||
.zip(recognition.character_confidences.iter())
|
||||
.map(|(ch, &conf)| Character {
|
||||
char: ch,
|
||||
confidence: conf,
|
||||
bbox: None, // Could be populated if available from model
|
||||
})
|
||||
.collect();
|
||||
|
||||
text_regions.push(TextRegion {
|
||||
bbox: detection.bbox,
|
||||
text: decoded_text,
|
||||
confidence,
|
||||
region_type,
|
||||
characters,
|
||||
});
|
||||
}
|
||||
|
||||
// Step 3: Combine results
|
||||
let combined_text = text_regions
|
||||
.iter()
|
||||
.map(|r| r.text.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join(" ");
|
||||
|
||||
let overall_confidence = if text_regions.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
|
||||
};
|
||||
|
||||
let processing_time_ms = start.elapsed().as_millis() as u64;
|
||||
|
||||
debug!(
|
||||
"OCR completed in {}ms, recognized {} regions",
|
||||
processing_time_ms,
|
||||
text_regions.len()
|
||||
);
|
||||
|
||||
Ok(OcrResult {
|
||||
text: combined_text,
|
||||
confidence: overall_confidence,
|
||||
regions: text_regions,
|
||||
has_math,
|
||||
processing_time_ms,
|
||||
})
|
||||
}
|
||||
|
||||
/// Batch process multiple images
|
||||
pub async fn recognize_batch(
|
||||
&self,
|
||||
images: &[&[u8]],
|
||||
options: &OcrOptions,
|
||||
) -> Result<Vec<OcrResult>> {
|
||||
info!("Processing batch of {} images", images.len());
|
||||
let start = Instant::now();
|
||||
|
||||
// Process images in parallel using rayon
|
||||
let results: Result<Vec<OcrResult>> = images
|
||||
.iter()
|
||||
.map(|image_data| {
|
||||
// Note: In a real async implementation, we'd use tokio::spawn
|
||||
// For now, we'll use blocking since we're in a sync context
|
||||
futures::executor::block_on(self.recognize_with_options(image_data, options))
|
||||
})
|
||||
.collect();
|
||||
|
||||
info!("Batch processing completed in {:?}", start.elapsed());
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
/// Decode recognition output using the selected decoder
|
||||
fn decode_output(
|
||||
&self,
|
||||
recognition: &RecognitionResult,
|
||||
options: &OcrOptions,
|
||||
) -> Result<String> {
|
||||
debug!("Decoding output with {:?} decoder", options.decoder_type);
|
||||
|
||||
let decoded = match options.decoder_type {
|
||||
DecoderType::BeamSearch => {
|
||||
let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
|
||||
decoder.decode(&recognition.logits)?
|
||||
}
|
||||
DecoderType::Greedy => {
|
||||
let decoder = GreedyDecoder::new(self.vocabulary.clone());
|
||||
decoder.decode(&recognition.logits)?
|
||||
}
|
||||
DecoderType::CTC => {
|
||||
let decoder = CTCDecoder::new(self.vocabulary.clone());
|
||||
decoder.decode(&recognition.logits)?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(decoded)
|
||||
}
|
||||
|
||||
/// Get the current model registry
|
||||
pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
|
||||
Arc::clone(&self.registry)
|
||||
}
|
||||
|
||||
/// Get the default options
|
||||
pub fn default_options(&self) -> &OcrOptions {
|
||||
&self.default_options
|
||||
}
|
||||
|
||||
/// Check if engine is warmed up
|
||||
pub fn is_warmed_up(&self) -> bool {
|
||||
*self.warmed_up.read()
|
||||
}
|
||||
}
|
||||
|
||||
impl OcrProcessor for OcrEngine {
|
||||
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
|
||||
// Blocking wrapper for async method
|
||||
futures::executor::block_on(self.recognize_with_options(image_data, options))
|
||||
}
|
||||
|
||||
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
|
||||
// Blocking wrapper for async method
|
||||
futures::executor::block_on(self.recognize_batch(images, options))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_decoder_selection() {
|
||||
let options = OcrOptions {
|
||||
decoder_type: DecoderType::BeamSearch,
|
||||
..Default::default()
|
||||
};
|
||||
assert_eq!(options.decoder_type, DecoderType::BeamSearch);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_warmup_flag() {
|
||||
let flag = Arc::new(RwLock::new(false));
|
||||
assert!(!*flag.read());
|
||||
*flag.write() = true;
|
||||
assert!(*flag.read());
|
||||
}
|
||||
}
|
||||
790
vendor/ruvector/examples/scipix/src/ocr/inference.rs
vendored
Normal file
790
vendor/ruvector/examples/scipix/src/ocr/inference.rs
vendored
Normal file
@@ -0,0 +1,790 @@
|
||||
//! ONNX Inference Module
|
||||
//!
|
||||
//! This module handles ONNX inference operations for text detection,
|
||||
//! character recognition, and mathematical expression recognition.
|
||||
//!
|
||||
//! # Model Requirements
|
||||
//!
|
||||
//! This module requires ONNX models to be available in the configured model directory.
|
||||
//! Without models, all inference operations will return errors.
|
||||
//!
|
||||
//! To use this module:
|
||||
//! 1. Download compatible ONNX models (PaddleOCR, TrOCR, or similar)
|
||||
//! 2. Place them in the models directory
|
||||
//! 3. Enable the `ocr` feature flag
|
||||
|
||||
use super::{models::ModelHandle, OcrError, OcrOptions, Result};
|
||||
use image::{DynamicImage, GenericImageView};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
use ndarray::Array4;
|
||||
#[cfg(feature = "ocr")]
|
||||
use ort::value::Tensor;
|
||||
|
||||
/// Result from text detection
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DetectionResult {
|
||||
/// Bounding box [x, y, width, height]
|
||||
pub bbox: [f32; 4],
|
||||
/// Detection confidence
|
||||
pub confidence: f32,
|
||||
/// Cropped image region
|
||||
pub region_image: Vec<u8>,
|
||||
/// Whether this region likely contains math
|
||||
pub is_math_likely: bool,
|
||||
}
|
||||
|
||||
/// Result from text/math recognition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RecognitionResult {
|
||||
/// Logits output from the model [sequence_length, vocab_size]
|
||||
pub logits: Vec<Vec<f32>>,
|
||||
/// Character-level confidence scores
|
||||
pub character_confidences: Vec<f32>,
|
||||
/// Raw output tensor (for debugging)
|
||||
pub raw_output: Option<Vec<f32>>,
|
||||
}
|
||||
|
||||
/// Inference engine for running ONNX models
|
||||
///
|
||||
/// IMPORTANT: This engine requires ONNX models to be loaded.
|
||||
/// All methods will return errors if models are not properly initialized.
|
||||
pub struct InferenceEngine {
|
||||
/// Detection model
|
||||
detection_model: Arc<ModelHandle>,
|
||||
/// Recognition model
|
||||
recognition_model: Arc<ModelHandle>,
|
||||
/// Math recognition model (optional)
|
||||
math_model: Option<Arc<ModelHandle>>,
|
||||
/// Whether to use GPU acceleration
|
||||
use_gpu: bool,
|
||||
/// Whether models are actually loaded (vs placeholder handles)
|
||||
models_loaded: bool,
|
||||
}
|
||||
|
||||
impl InferenceEngine {
|
||||
/// Create a new inference engine
|
||||
pub fn new(
|
||||
detection_model: Arc<ModelHandle>,
|
||||
recognition_model: Arc<ModelHandle>,
|
||||
math_model: Option<Arc<ModelHandle>>,
|
||||
use_gpu: bool,
|
||||
) -> Result<Self> {
|
||||
// Check if models are actually loaded with ONNX sessions
|
||||
let detection_loaded = detection_model.is_loaded();
|
||||
let recognition_loaded = recognition_model.is_loaded();
|
||||
let models_loaded = detection_loaded && recognition_loaded;
|
||||
|
||||
if !models_loaded {
|
||||
warn!(
|
||||
"ONNX models not fully loaded. Detection: {}, Recognition: {}",
|
||||
detection_loaded, recognition_loaded
|
||||
);
|
||||
warn!("OCR inference will fail until models are properly configured.");
|
||||
} else {
|
||||
info!(
|
||||
"Inference engine initialized with loaded models (GPU: {})",
|
||||
if use_gpu { "enabled" } else { "disabled" }
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
detection_model,
|
||||
recognition_model,
|
||||
math_model,
|
||||
use_gpu,
|
||||
models_loaded,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the inference engine is ready for use
|
||||
pub fn is_ready(&self) -> bool {
|
||||
self.models_loaded
|
||||
}
|
||||
|
||||
/// Run text detection on an image
|
||||
pub async fn run_detection(
|
||||
&self,
|
||||
image_data: &[u8],
|
||||
threshold: f32,
|
||||
) -> Result<Vec<DetectionResult>> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Please download and configure OCR models before use. \
|
||||
See examples/scipix/docs/MODEL_SETUP.md for instructions."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running text detection (threshold: {})", threshold);
|
||||
let input_tensor = self.preprocess_image_for_detection(image_data)?;
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
{
|
||||
let detections = self
|
||||
.run_onnx_detection(&input_tensor, threshold, image_data)
|
||||
.await?;
|
||||
debug!("Detected {} regions", detections.len());
|
||||
return Ok(detections);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ocr"))]
|
||||
{
|
||||
Err(OcrError::Inference(
|
||||
"OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Run text recognition on a region image
|
||||
pub async fn run_recognition(
|
||||
&self,
|
||||
region_image: &[u8],
|
||||
options: &OcrOptions,
|
||||
) -> Result<RecognitionResult> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Please download and configure OCR models before use."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running text recognition");
|
||||
let input_tensor = self.preprocess_image_for_recognition(region_image)?;
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
{
|
||||
let result = self.run_onnx_recognition(&input_tensor, options).await?;
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ocr"))]
|
||||
{
|
||||
Err(OcrError::Inference(
|
||||
"OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Run math recognition on a region image
|
||||
pub async fn run_math_recognition(
|
||||
&self,
|
||||
region_image: &[u8],
|
||||
options: &OcrOptions,
|
||||
) -> Result<RecognitionResult> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Please download and configure OCR models before use."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running math recognition");
|
||||
|
||||
if self.math_model.is_none() || !self.math_model.as_ref().unwrap().is_loaded() {
|
||||
warn!("Math model not loaded, falling back to text recognition");
|
||||
return self.run_recognition(region_image, options).await;
|
||||
}
|
||||
|
||||
let input_tensor = self.preprocess_image_for_math(region_image)?;
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
{
|
||||
let result = self
|
||||
.run_onnx_math_recognition(&input_tensor, options)
|
||||
.await?;
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ocr"))]
|
||||
{
|
||||
Err(OcrError::Inference(
|
||||
"OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Preprocess image for detection model
|
||||
fn preprocess_image_for_detection(&self, image_data: &[u8]) -> Result<Vec<f32>> {
|
||||
let img = image::load_from_memory(image_data)
|
||||
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let input_shape = self.detection_model.input_shape();
|
||||
let (_, _, height, width) = (
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
input_shape[3],
|
||||
);
|
||||
|
||||
let resized = img.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
|
||||
let rgb = resized.to_rgb8();
|
||||
let mut tensor = Vec::with_capacity(3 * height * width);
|
||||
|
||||
// Convert to NCHW format with normalization
|
||||
for c in 0..3 {
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||
tensor.push(pixel[c] as f32 / 255.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Preprocess image for recognition model
|
||||
fn preprocess_image_for_recognition(&self, image_data: &[u8]) -> Result<Vec<f32>> {
|
||||
let img = image::load_from_memory(image_data)
|
||||
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let input_shape = self.recognition_model.input_shape();
|
||||
let (_, channels, height, width) = (
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
input_shape[3],
|
||||
);
|
||||
|
||||
let resized = img.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
|
||||
let mut tensor = Vec::with_capacity(channels * height * width);
|
||||
|
||||
if channels == 1 {
|
||||
let gray = resized.to_luma8();
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = gray.get_pixel(x as u32, y as u32);
|
||||
tensor.push((pixel[0] as f32 / 127.5) - 1.0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let rgb = resized.to_rgb8();
|
||||
for c in 0..3 {
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||
tensor.push((pixel[c] as f32 / 127.5) - 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Preprocess image for math recognition model
|
||||
fn preprocess_image_for_math(&self, image_data: &[u8]) -> Result<Vec<f32>> {
|
||||
let math_model = self
|
||||
.math_model
|
||||
.as_ref()
|
||||
.ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
|
||||
|
||||
let img = image::load_from_memory(image_data)
|
||||
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let input_shape = math_model.input_shape();
|
||||
let (_, channels, height, width) = (
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
input_shape[3],
|
||||
);
|
||||
|
||||
let resized = img.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
|
||||
let mut tensor = Vec::with_capacity(channels * height * width);
|
||||
|
||||
if channels == 1 {
|
||||
let gray = resized.to_luma8();
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = gray.get_pixel(x as u32, y as u32);
|
||||
tensor.push((pixel[0] as f32 / 127.5) - 1.0);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let rgb = resized.to_rgb8();
|
||||
for c in 0..channels {
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||
tensor.push((pixel[c] as f32 / 127.5) - 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// ONNX detection inference (requires `ocr` feature)
|
||||
#[cfg(feature = "ocr")]
|
||||
async fn run_onnx_detection(
|
||||
&self,
|
||||
input_tensor: &[f32],
|
||||
threshold: f32,
|
||||
original_image: &[u8],
|
||||
) -> Result<Vec<DetectionResult>> {
|
||||
let session_arc = self.detection_model.session().ok_or_else(|| {
|
||||
OcrError::OnnxRuntime("Detection model session not loaded".to_string())
|
||||
})?;
|
||||
let mut session = session_arc.lock();
|
||||
|
||||
let input_shape = self.detection_model.input_shape();
|
||||
let shape: Vec<usize> = input_shape.to_vec();
|
||||
|
||||
// Create tensor from input data
|
||||
let input_array = Array4::from_shape_vec(
|
||||
(shape[0], shape[1], shape[2], shape[3]),
|
||||
input_tensor.to_vec(),
|
||||
)
|
||||
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
|
||||
|
||||
// Convert to dynamic-dimension view and create ORT tensor
|
||||
let input_dyn = input_array.into_dyn();
|
||||
let input_tensor = Tensor::from_array(input_dyn)
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
|
||||
|
||||
// Run inference
|
||||
let outputs = session
|
||||
.run(ort::inputs![input_tensor])
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Inference failed: {}", e)))?;
|
||||
|
||||
let output_tensor = outputs
|
||||
.iter()
|
||||
.next()
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
|
||||
|
||||
let (_, raw_data) = output_tensor
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
|
||||
let output_data: Vec<f32> = raw_data.to_vec();
|
||||
|
||||
let original_img = image::load_from_memory(original_image)
|
||||
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let detections = self.parse_detection_output(&output_data, threshold, &original_img)?;
|
||||
Ok(detections)
|
||||
}
|
||||
|
||||
/// Parse detection model output
|
||||
#[cfg(feature = "ocr")]
|
||||
fn parse_detection_output(
|
||||
&self,
|
||||
output: &[f32],
|
||||
threshold: f32,
|
||||
original_img: &DynamicImage,
|
||||
) -> Result<Vec<DetectionResult>> {
|
||||
let mut results = Vec::new();
|
||||
let output_shape = self.detection_model.output_shape();
|
||||
|
||||
if output_shape.len() >= 2 {
|
||||
let num_detections = output_shape[1];
|
||||
let detection_size = if output_shape.len() >= 3 {
|
||||
output_shape[2]
|
||||
} else {
|
||||
85
|
||||
};
|
||||
|
||||
for i in 0..num_detections {
|
||||
let base_idx = i * detection_size;
|
||||
if base_idx + 5 > output.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let confidence = output[base_idx + 4];
|
||||
if confidence < threshold {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cx = output[base_idx];
|
||||
let cy = output[base_idx + 1];
|
||||
let w = output[base_idx + 2];
|
||||
let h = output[base_idx + 3];
|
||||
|
||||
let img_width = original_img.width() as f32;
|
||||
let img_height = original_img.height() as f32;
|
||||
|
||||
let x = ((cx - w / 2.0) * img_width).max(0.0);
|
||||
let y = ((cy - h / 2.0) * img_height).max(0.0);
|
||||
let width = (w * img_width).min(img_width - x);
|
||||
let height = (h * img_height).min(img_height - y);
|
||||
|
||||
if width <= 0.0 || height <= 0.0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let cropped =
|
||||
original_img.crop_imm(x as u32, y as u32, width as u32, height as u32);
|
||||
|
||||
let mut region_bytes = Vec::new();
|
||||
cropped
|
||||
.write_to(
|
||||
&mut std::io::Cursor::new(&mut region_bytes),
|
||||
image::ImageFormat::Png,
|
||||
)
|
||||
.map_err(|e| {
|
||||
OcrError::ImageProcessing(format!("Failed to encode region: {}", e))
|
||||
})?;
|
||||
|
||||
let aspect_ratio = width / height;
|
||||
let is_math_likely = aspect_ratio > 2.0 || aspect_ratio < 0.5;
|
||||
|
||||
results.push(DetectionResult {
|
||||
bbox: [x, y, width, height],
|
||||
confidence,
|
||||
region_image: region_bytes,
|
||||
is_math_likely,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// ONNX recognition inference (requires `ocr` feature)
|
||||
#[cfg(feature = "ocr")]
|
||||
async fn run_onnx_recognition(
|
||||
&self,
|
||||
input_tensor: &[f32],
|
||||
_options: &OcrOptions,
|
||||
) -> Result<RecognitionResult> {
|
||||
let session_arc = self.recognition_model.session().ok_or_else(|| {
|
||||
OcrError::OnnxRuntime("Recognition model session not loaded".to_string())
|
||||
})?;
|
||||
let mut session = session_arc.lock();
|
||||
|
||||
let input_shape = self.recognition_model.input_shape();
|
||||
let shape: Vec<usize> = input_shape.to_vec();
|
||||
|
||||
let input_array = Array4::from_shape_vec(
|
||||
(shape[0], shape[1], shape[2], shape[3]),
|
||||
input_tensor.to_vec(),
|
||||
)
|
||||
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
|
||||
|
||||
let input_dyn = input_array.into_dyn();
|
||||
let input_ort = Tensor::from_array(input_dyn)
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
|
||||
|
||||
let outputs = session
|
||||
.run(ort::inputs![input_ort])
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Recognition inference failed: {}", e)))?;
|
||||
|
||||
let output_tensor = outputs
|
||||
.iter()
|
||||
.next()
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
|
||||
|
||||
let (_, raw_data) = output_tensor
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
|
||||
let output_data: Vec<f32> = raw_data.to_vec();
|
||||
|
||||
let output_shape = self.recognition_model.output_shape();
|
||||
let seq_len = output_shape.get(1).copied().unwrap_or(26);
|
||||
let vocab_size = output_shape.get(2).copied().unwrap_or(37);
|
||||
|
||||
let mut logits = Vec::new();
|
||||
let mut character_confidences = Vec::new();
|
||||
|
||||
for i in 0..seq_len {
|
||||
let start_idx = i * vocab_size;
|
||||
let end_idx = start_idx + vocab_size;
|
||||
|
||||
if end_idx <= output_data.len() {
|
||||
let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
|
||||
|
||||
let max_logit = step_logits
|
||||
.iter()
|
||||
.cloned()
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
|
||||
|
||||
let softmax: Vec<f32> = step_logits
|
||||
.iter()
|
||||
.map(|&x| (x - max_logit).exp() / exp_sum)
|
||||
.collect();
|
||||
|
||||
let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
|
||||
character_confidences.push(max_confidence);
|
||||
logits.push(step_logits);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RecognitionResult {
|
||||
logits,
|
||||
character_confidences,
|
||||
raw_output: Some(output_data),
|
||||
})
|
||||
}
|
||||
|
||||
/// ONNX math recognition inference (requires `ocr` feature)
|
||||
#[cfg(feature = "ocr")]
|
||||
async fn run_onnx_math_recognition(
|
||||
&self,
|
||||
input_tensor: &[f32],
|
||||
_options: &OcrOptions,
|
||||
) -> Result<RecognitionResult> {
|
||||
let math_model = self
|
||||
.math_model
|
||||
.as_ref()
|
||||
.ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
|
||||
|
||||
let session_arc = math_model
|
||||
.session()
|
||||
.ok_or_else(|| OcrError::OnnxRuntime("Math model session not loaded".to_string()))?;
|
||||
let mut session = session_arc.lock();
|
||||
|
||||
let input_shape = math_model.input_shape();
|
||||
let shape: Vec<usize> = input_shape.to_vec();
|
||||
|
||||
let input_array = Array4::from_shape_vec(
|
||||
(shape[0], shape[1], shape[2], shape[3]),
|
||||
input_tensor.to_vec(),
|
||||
)
|
||||
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
|
||||
|
||||
let input_dyn = input_array.into_dyn();
|
||||
let input_ort = Tensor::from_array(input_dyn)
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
|
||||
|
||||
let outputs = session.run(ort::inputs![input_ort]).map_err(|e| {
|
||||
OcrError::OnnxRuntime(format!("Math recognition inference failed: {}", e))
|
||||
})?;
|
||||
|
||||
let output_tensor = outputs
|
||||
.iter()
|
||||
.next()
|
||||
.map(|(_, v)| v)
|
||||
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
|
||||
|
||||
let (_, raw_data) = output_tensor
|
||||
.try_extract_tensor::<f32>()
|
||||
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
|
||||
let output_data: Vec<f32> = raw_data.to_vec();
|
||||
|
||||
let output_shape = math_model.output_shape();
|
||||
let seq_len = output_shape.get(1).copied().unwrap_or(50);
|
||||
let vocab_size = output_shape.get(2).copied().unwrap_or(512);
|
||||
|
||||
let mut logits = Vec::new();
|
||||
let mut character_confidences = Vec::new();
|
||||
|
||||
for i in 0..seq_len {
|
||||
let start_idx = i * vocab_size;
|
||||
let end_idx = start_idx + vocab_size;
|
||||
|
||||
if end_idx <= output_data.len() {
|
||||
let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
|
||||
|
||||
let max_logit = step_logits
|
||||
.iter()
|
||||
.cloned()
|
||||
.fold(f32::NEG_INFINITY, f32::max);
|
||||
let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
|
||||
|
||||
let softmax: Vec<f32> = step_logits
|
||||
.iter()
|
||||
.map(|&x| (x - max_logit).exp() / exp_sum)
|
||||
.collect();
|
||||
|
||||
let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
|
||||
character_confidences.push(max_confidence);
|
||||
logits.push(step_logits);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(RecognitionResult {
|
||||
logits,
|
||||
character_confidences,
|
||||
raw_output: Some(output_data),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get detection model
|
||||
pub fn detection_model(&self) -> &ModelHandle {
|
||||
&self.detection_model
|
||||
}
|
||||
|
||||
/// Get recognition model
|
||||
pub fn recognition_model(&self) -> &ModelHandle {
|
||||
&self.recognition_model
|
||||
}
|
||||
|
||||
/// Get math model if available
|
||||
pub fn math_model(&self) -> Option<&ModelHandle> {
|
||||
self.math_model.as_ref().map(|m| m.as_ref())
|
||||
}
|
||||
|
||||
/// Check if GPU acceleration is enabled
|
||||
pub fn is_gpu_enabled(&self) -> bool {
|
||||
self.use_gpu
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch inference optimization
|
||||
impl InferenceEngine {
|
||||
/// Run batch detection on multiple images
|
||||
pub async fn run_batch_detection(
|
||||
&self,
|
||||
images: &[&[u8]],
|
||||
threshold: f32,
|
||||
) -> Result<Vec<Vec<DetectionResult>>> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Cannot run batch detection.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running batch detection on {} images", images.len());
|
||||
|
||||
let mut results = Vec::new();
|
||||
for image in images {
|
||||
let detections = self.run_detection(image, threshold).await?;
|
||||
results.push(detections);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Run batch recognition on multiple regions
|
||||
pub async fn run_batch_recognition(
|
||||
&self,
|
||||
regions: &[&[u8]],
|
||||
options: &OcrOptions,
|
||||
) -> Result<Vec<RecognitionResult>> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Cannot run batch recognition.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running batch recognition on {} regions", regions.len());
|
||||
|
||||
let mut results = Vec::new();
|
||||
for region in regions {
|
||||
let result = self.run_recognition(region, options).await?;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::ocr::models::{ModelMetadata, ModelType};
|
||||
use std::path::PathBuf;
|
||||
|
||||
fn create_test_model(model_type: ModelType, path: PathBuf) -> Arc<ModelHandle> {
|
||||
let metadata = ModelMetadata {
|
||||
name: format!("{:?} Model", model_type),
|
||||
version: "1.0.0".to_string(),
|
||||
input_shape: vec![1, 3, 640, 640],
|
||||
output_shape: vec![1, 100, 85],
|
||||
input_dtype: "float32".to_string(),
|
||||
file_size: 1000,
|
||||
checksum: None,
|
||||
};
|
||||
|
||||
Arc::new(ModelHandle::new(model_type, path, metadata).unwrap())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_inference_engine_creation_without_models() {
|
||||
let detection = create_test_model(
|
||||
ModelType::Detection,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
let recognition = create_test_model(
|
||||
ModelType::Recognition,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
|
||||
let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
|
||||
assert!(!engine.is_ready());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_detection_fails_without_models() {
|
||||
let detection = create_test_model(
|
||||
ModelType::Detection,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
let recognition = create_test_model(
|
||||
ModelType::Recognition,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
|
||||
|
||||
let png_data = create_test_png();
|
||||
let result = engine.run_detection(&png_data, 0.5).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_recognition_fails_without_models() {
|
||||
let detection = create_test_model(
|
||||
ModelType::Detection,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
let recognition = create_test_model(
|
||||
ModelType::Recognition,
|
||||
PathBuf::from("/nonexistent/model.onnx"),
|
||||
);
|
||||
let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
|
||||
|
||||
let png_data = create_test_png();
|
||||
let options = OcrOptions::default();
|
||||
let result = engine.run_recognition(&png_data, &options).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_ready_reflects_model_state() {
|
||||
let detection = create_test_model(ModelType::Detection, PathBuf::from("/fake/path"));
|
||||
let recognition = create_test_model(ModelType::Recognition, PathBuf::from("/fake/path"));
|
||||
let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
|
||||
assert!(!engine.is_ready());
|
||||
}
|
||||
|
||||
fn create_test_png() -> Vec<u8> {
|
||||
use image::{ImageBuffer, RgbImage};
|
||||
let img: RgbImage = ImageBuffer::from_fn(10, 10, |_, _| image::Rgb([255, 255, 255]));
|
||||
let mut bytes: Vec<u8> = Vec::new();
|
||||
img.write_to(
|
||||
&mut std::io::Cursor::new(&mut bytes),
|
||||
image::ImageFormat::Png,
|
||||
)
|
||||
.unwrap();
|
||||
bytes
|
||||
}
|
||||
}
|
||||
235
vendor/ruvector/examples/scipix/src/ocr/mod.rs
vendored
Normal file
235
vendor/ruvector/examples/scipix/src/ocr/mod.rs
vendored
Normal file
@@ -0,0 +1,235 @@
|
||||
//! OCR Engine Module
|
||||
//!
|
||||
//! This module provides optical character recognition capabilities for the ruvector-scipix system.
|
||||
//! It supports text detection, character recognition, and mathematical expression recognition using
|
||||
//! ONNX models for high-performance inference.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The OCR module is organized into several submodules:
|
||||
//! - `engine`: Main OcrEngine for orchestrating OCR operations
|
||||
//! - `models`: Model management, loading, and caching
|
||||
//! - `inference`: ONNX inference operations for detection and recognition
|
||||
//! - `decoder`: Output decoding strategies (beam search, greedy, CTC)
|
||||
//! - `confidence`: Confidence scoring and calibration
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ruvector_scipix::ocr::{OcrEngine, OcrOptions};
|
||||
//!
|
||||
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! // Initialize the OCR engine
|
||||
//! let engine = OcrEngine::new().await?;
|
||||
//!
|
||||
//! // Load an image
|
||||
//! let image_data = std::fs::read("math_formula.png")?;
|
||||
//!
|
||||
//! // Perform OCR
|
||||
//! let result = engine.recognize(&image_data).await?;
|
||||
//!
|
||||
//! println!("Recognized text: {}", result.text);
|
||||
//! println!("Confidence: {:.2}%", result.confidence * 100.0);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Submodules
|
||||
mod confidence;
|
||||
mod decoder;
|
||||
mod engine;
|
||||
mod inference;
|
||||
mod models;
|
||||
|
||||
// Public exports
|
||||
pub use confidence::{aggregate_confidence, calculate_confidence, ConfidenceCalibrator};
|
||||
pub use decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary};
|
||||
pub use engine::{OcrEngine, OcrProcessor};
|
||||
pub use inference::{DetectionResult, InferenceEngine, RecognitionResult};
|
||||
pub use models::{ModelHandle, ModelRegistry};
|
||||
|
||||
/// OCR processing options
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrOptions {
|
||||
/// Detection threshold for text regions (0.0-1.0)
|
||||
pub detection_threshold: f32,
|
||||
|
||||
/// Recognition confidence threshold (0.0-1.0)
|
||||
pub recognition_threshold: f32,
|
||||
|
||||
/// Enable mathematical expression recognition
|
||||
pub enable_math: bool,
|
||||
|
||||
/// Decoder type to use
|
||||
pub decoder_type: DecoderType,
|
||||
|
||||
/// Beam width for beam search decoder
|
||||
pub beam_width: usize,
|
||||
|
||||
/// Maximum batch size for inference
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Enable GPU acceleration if available
|
||||
pub use_gpu: bool,
|
||||
|
||||
/// Language hints for recognition
|
||||
pub languages: Vec<String>,
|
||||
}
|
||||
|
||||
impl Default for OcrOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
detection_threshold: 0.5,
|
||||
recognition_threshold: 0.6,
|
||||
enable_math: true,
|
||||
decoder_type: DecoderType::BeamSearch,
|
||||
beam_width: 5,
|
||||
batch_size: 1,
|
||||
use_gpu: false,
|
||||
languages: vec!["en".to_string()],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Decoder type selection
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DecoderType {
|
||||
/// Beam search decoder (higher quality, slower)
|
||||
BeamSearch,
|
||||
/// Greedy decoder (faster, lower quality)
|
||||
Greedy,
|
||||
/// CTC decoder for sequence-to-sequence models
|
||||
CTC,
|
||||
}
|
||||
|
||||
/// OCR result containing recognized text and metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrResult {
|
||||
/// Recognized text
|
||||
pub text: String,
|
||||
|
||||
/// Overall confidence score (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
|
||||
/// Detected text regions with their bounding boxes
|
||||
pub regions: Vec<TextRegion>,
|
||||
|
||||
/// Whether mathematical expressions were detected
|
||||
pub has_math: bool,
|
||||
|
||||
/// Processing time in milliseconds
|
||||
pub processing_time_ms: u64,
|
||||
}
|
||||
|
||||
/// A detected text region with position and content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TextRegion {
|
||||
/// Bounding box coordinates [x, y, width, height]
|
||||
pub bbox: [f32; 4],
|
||||
|
||||
/// Recognized text in this region
|
||||
pub text: String,
|
||||
|
||||
/// Confidence score for this region (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
|
||||
/// Region type (text, math, etc.)
|
||||
pub region_type: RegionType,
|
||||
|
||||
/// Character-level details if available
|
||||
pub characters: Vec<Character>,
|
||||
}
|
||||
|
||||
/// Type of text region
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum RegionType {
|
||||
/// Regular text
|
||||
Text,
|
||||
/// Mathematical expression
|
||||
Math,
|
||||
/// Diagram or figure
|
||||
Diagram,
|
||||
/// Table
|
||||
Table,
|
||||
/// Unknown type
|
||||
Unknown,
|
||||
}
|
||||
|
||||
/// Individual character with position and confidence
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Character {
|
||||
/// The character
|
||||
pub char: char,
|
||||
|
||||
/// Confidence score (0.0-1.0)
|
||||
pub confidence: f32,
|
||||
|
||||
/// Bounding box if available
|
||||
pub bbox: Option<[f32; 4]>,
|
||||
}
|
||||
|
||||
/// Error types for OCR operations
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum OcrError {
|
||||
#[error("Model loading error: {0}")]
|
||||
ModelLoading(String),
|
||||
|
||||
#[error("Inference error: {0}")]
|
||||
Inference(String),
|
||||
|
||||
#[error("Image processing error: {0}")]
|
||||
ImageProcessing(String),
|
||||
|
||||
#[error("Decoding error: {0}")]
|
||||
Decoding(String),
|
||||
|
||||
#[error("Invalid configuration: {0}")]
|
||||
InvalidConfig(String),
|
||||
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("ONNX Runtime error: {0}")]
|
||||
OnnxRuntime(String),
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OcrError>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ocr_options_default() {
        // The documented defaults must hold.
        let opts = OcrOptions::default();
        assert_eq!(opts.detection_threshold, 0.5);
        assert_eq!(opts.recognition_threshold, 0.6);
        assert!(opts.enable_math);
        assert_eq!(opts.decoder_type, DecoderType::BeamSearch);
        assert_eq!(opts.beam_width, 5);
    }

    #[test]
    fn test_text_region_creation() {
        // A region constructed by hand round-trips its fields.
        let region = TextRegion {
            bbox: [10.0, 20.0, 100.0, 30.0],
            text: String::from("Test"),
            confidence: 0.95,
            region_type: RegionType::Text,
            characters: Vec::new(),
        };
        assert_eq!(region.bbox[0], 10.0);
        assert_eq!(region.text, "Test");
        assert_eq!(region.region_type, RegionType::Text);
    }

    #[test]
    fn test_decoder_type_equality() {
        assert_eq!(DecoderType::BeamSearch, DecoderType::BeamSearch);
        assert_ne!(DecoderType::BeamSearch, DecoderType::Greedy);
        assert_ne!(DecoderType::Greedy, DecoderType::CTC);
    }
}
|
||||
373
vendor/ruvector/examples/scipix/src/ocr/models.rs
vendored
Normal file
373
vendor/ruvector/examples/scipix/src/ocr/models.rs
vendored
Normal file
@@ -0,0 +1,373 @@
|
||||
//! Model Management Module
|
||||
//!
|
||||
//! This module handles loading, caching, and managing ONNX models for OCR.
|
||||
//! It supports lazy loading, model downloading with progress tracking,
|
||||
//! and checksum verification.
|
||||
|
||||
use super::{OcrError, Result};
|
||||
use dashmap::DashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
use ort::session::Session;
|
||||
#[cfg(feature = "ocr")]
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// The kinds of ONNX models the OCR engine can load.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Locates text regions within an image
    Detection,
    /// Recognizes the characters inside a detected region
    Recognition,
    /// Recognizes mathematical expressions
    Math,
}
|
||||
|
||||
/// Handle to a loaded ONNX model
///
/// Cloning a handle is cheap: the session (when present) is shared
/// behind an `Arc`, so all clones refer to the same loaded model.
#[derive(Clone)]
pub struct ModelHandle {
    /// Model type
    model_type: ModelType,
    /// Path to the model file
    path: PathBuf,
    /// Model metadata
    metadata: ModelMetadata,
    /// ONNX Runtime session (when ocr feature is enabled)
    /// Wrapped in Mutex for mutable access required by ort 2.0 Session::run
    /// `None` when the model file was missing or failed to load (see `new`).
    #[cfg(feature = "ocr")]
    session: Option<Arc<Mutex<Session>>>,
    /// Mock session for when ocr feature is disabled
    /// (always `None`; kept so the struct compiles identically either way)
    #[cfg(not(feature = "ocr"))]
    #[allow(dead_code)]
    session: Option<()>,
}
|
||||
|
||||
impl ModelHandle {
    /// Create a new model handle
    ///
    /// Session loading is best-effort: if the file is missing or ONNX
    /// loading fails, the handle is still created with `session == None`
    /// (callers detect this via `is_loaded()`); only warnings are logged.
    pub fn new(model_type: ModelType, path: PathBuf, metadata: ModelMetadata) -> Result<Self> {
        debug!("Creating model handle for {:?} at {:?}", model_type, path);

        // With the "ocr" feature, attempt to build an ONNX session from the
        // file; every failure path degrades to None rather than an error.
        #[cfg(feature = "ocr")]
        let session = if path.exists() {
            match Session::builder() {
                Ok(builder) => match builder.commit_from_file(&path) {
                    Ok(session) => {
                        info!("Successfully loaded ONNX model: {:?}", path);
                        Some(Arc::new(Mutex::new(session)))
                    }
                    Err(e) => {
                        warn!("Failed to load ONNX model {:?}: {}", path, e);
                        None
                    }
                },
                Err(e) => {
                    warn!("Failed to create ONNX session builder: {}", e);
                    None
                }
            }
        } else {
            debug!("Model file not found: {:?}", path);
            None
        };

        // Without the feature there is never a session.
        #[cfg(not(feature = "ocr"))]
        let session: Option<()> = None;

        Ok(Self {
            model_type,
            path,
            metadata,
            session,
        })
    }

    /// Check if the model session is loaded
    pub fn is_loaded(&self) -> bool {
        self.session.is_some()
    }

    /// Get the ONNX session (only available with ocr feature)
    #[cfg(feature = "ocr")]
    pub fn session(&self) -> Option<&Arc<Mutex<Session>>> {
        self.session.as_ref()
    }

    /// Get the model type
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }

    /// Get the model path
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Get model metadata
    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Get input shape for the model (taken from metadata, not the session)
    pub fn input_shape(&self) -> &[usize] {
        &self.metadata.input_shape
    }

    /// Get output shape for the model (taken from metadata, not the session)
    pub fn output_shape(&self) -> &[usize] {
        &self.metadata.output_shape
    }
}
|
||||
|
||||
/// Descriptive information about a model file.
#[derive(Clone, Debug)]
pub struct ModelMetadata {
    /// Human-readable model name
    pub name: String,
    /// Version string
    pub version: String,
    /// Expected input tensor shape
    pub input_shape: Vec<usize>,
    /// Expected output tensor shape
    pub output_shape: Vec<usize>,
    /// Expected input element type (e.g. "float32")
    pub input_dtype: String,
    /// File size in bytes
    pub file_size: u64,
    /// Optional SHA256 checksum for integrity verification
    pub checksum: Option<String>,
}
|
||||
|
||||
/// Model registry for loading and caching models
|
||||
pub struct ModelRegistry {
|
||||
/// Cache of loaded models
|
||||
cache: DashMap<ModelType, Arc<ModelHandle>>,
|
||||
/// Base directory for models
|
||||
model_dir: PathBuf,
|
||||
/// Whether to enable lazy loading
|
||||
lazy_loading: bool,
|
||||
}
|
||||
|
||||
impl ModelRegistry {
|
||||
/// Create a new model registry
|
||||
pub fn new() -> Self {
|
||||
Self::with_model_dir(PathBuf::from("./models"))
|
||||
}
|
||||
|
||||
/// Create a new model registry with custom model directory
|
||||
pub fn with_model_dir(model_dir: PathBuf) -> Self {
|
||||
info!("Initializing model registry at {:?}", model_dir);
|
||||
Self {
|
||||
cache: DashMap::new(),
|
||||
model_dir,
|
||||
lazy_loading: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the detection model
|
||||
pub async fn load_detection_model(&mut self) -> Result<Arc<ModelHandle>> {
|
||||
self.load_model(ModelType::Detection).await
|
||||
}
|
||||
|
||||
/// Load the recognition model
|
||||
pub async fn load_recognition_model(&mut self) -> Result<Arc<ModelHandle>> {
|
||||
self.load_model(ModelType::Recognition).await
|
||||
}
|
||||
|
||||
/// Load the math recognition model
|
||||
pub async fn load_math_model(&mut self) -> Result<Arc<ModelHandle>> {
|
||||
self.load_model(ModelType::Math).await
|
||||
}
|
||||
|
||||
/// Load a model by type
|
||||
pub async fn load_model(&mut self, model_type: ModelType) -> Result<Arc<ModelHandle>> {
|
||||
// Check cache first
|
||||
if let Some(handle) = self.cache.get(&model_type) {
|
||||
debug!("Model {:?} found in cache", model_type);
|
||||
return Ok(Arc::clone(handle.value()));
|
||||
}
|
||||
|
||||
info!("Loading model {:?}...", model_type);
|
||||
|
||||
// Get model path
|
||||
let model_path = self.get_model_path(model_type);
|
||||
|
||||
// Check if model exists
|
||||
if !model_path.exists() {
|
||||
if self.lazy_loading {
|
||||
warn!(
|
||||
"Model {:?} not found at {:?}. OCR will not work without models.",
|
||||
model_type, model_path
|
||||
);
|
||||
warn!("Download models from: https://github.com/PaddlePaddle/PaddleOCR or configure custom models.");
|
||||
} else {
|
||||
return Err(OcrError::ModelLoading(format!(
|
||||
"Model {:?} not found at {:?}",
|
||||
model_type, model_path
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
// Load model metadata
|
||||
let metadata = self.get_model_metadata(model_type);
|
||||
|
||||
// Verify checksum if provided
|
||||
if let Some(ref checksum) = metadata.checksum {
|
||||
if model_path.exists() {
|
||||
debug!("Verifying model checksum: {}", checksum);
|
||||
// In production: verify_checksum(&model_path, checksum)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Create model handle (will load ONNX session if file exists)
|
||||
let handle = Arc::new(ModelHandle::new(model_type, model_path, metadata)?);
|
||||
|
||||
// Cache the handle
|
||||
self.cache.insert(model_type, Arc::clone(&handle));
|
||||
|
||||
if handle.is_loaded() {
|
||||
info!(
|
||||
"Model {:?} loaded successfully with ONNX session",
|
||||
model_type
|
||||
);
|
||||
} else {
|
||||
warn!(
|
||||
"Model {:?} handle created but ONNX session not loaded",
|
||||
model_type
|
||||
);
|
||||
}
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
/// Get the file path for a model type
|
||||
fn get_model_path(&self, model_type: ModelType) -> PathBuf {
|
||||
let filename = match model_type {
|
||||
ModelType::Detection => "text_detection.onnx",
|
||||
ModelType::Recognition => "text_recognition.onnx",
|
||||
ModelType::Math => "math_recognition.onnx",
|
||||
};
|
||||
self.model_dir.join(filename)
|
||||
}
|
||||
|
||||
/// Get default metadata for a model type
|
||||
fn get_model_metadata(&self, model_type: ModelType) -> ModelMetadata {
|
||||
match model_type {
|
||||
ModelType::Detection => ModelMetadata {
|
||||
name: "Text Detection".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
input_shape: vec![1, 3, 640, 640], // NCHW format
|
||||
output_shape: vec![1, 25200, 85], // Detections
|
||||
input_dtype: "float32".to_string(),
|
||||
file_size: 50_000_000, // ~50MB
|
||||
checksum: None,
|
||||
},
|
||||
ModelType::Recognition => ModelMetadata {
|
||||
name: "Text Recognition".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
input_shape: vec![1, 1, 32, 128], // NCHW format
|
||||
output_shape: vec![1, 26, 37], // Sequence length, vocab size
|
||||
input_dtype: "float32".to_string(),
|
||||
file_size: 20_000_000, // ~20MB
|
||||
checksum: None,
|
||||
},
|
||||
ModelType::Math => ModelMetadata {
|
||||
name: "Math Recognition".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
input_shape: vec![1, 1, 64, 256], // NCHW format
|
||||
output_shape: vec![1, 50, 512], // Sequence length, vocab size
|
||||
input_dtype: "float32".to_string(),
|
||||
file_size: 80_000_000, // ~80MB
|
||||
checksum: None,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Clear the model cache
|
||||
pub fn clear_cache(&mut self) {
|
||||
info!("Clearing model cache");
|
||||
self.cache.clear();
|
||||
}
|
||||
|
||||
/// Get a cached model if available
|
||||
pub fn get_cached(&self, model_type: ModelType) -> Option<Arc<ModelHandle>> {
|
||||
self.cache.get(&model_type).map(|h| Arc::clone(h.value()))
|
||||
}
|
||||
|
||||
/// Set lazy loading mode
|
||||
pub fn set_lazy_loading(&mut self, enabled: bool) {
|
||||
self.lazy_loading = enabled;
|
||||
}
|
||||
|
||||
/// Get the model directory
|
||||
pub fn model_dir(&self) -> &Path {
|
||||
&self.model_dir
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ModelRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_registry_creation() {
        // A fresh registry points at ./models with lazy loading on.
        let reg = ModelRegistry::new();
        assert_eq!(reg.model_dir(), Path::new("./models"));
        assert!(reg.lazy_loading);
    }

    #[test]
    fn test_model_path_generation() {
        let reg = ModelRegistry::new();
        let detection_path = reg.get_model_path(ModelType::Detection);
        assert!(detection_path
            .to_string_lossy()
            .contains("text_detection.onnx"));
    }

    #[test]
    fn test_model_metadata() {
        let reg = ModelRegistry::new();
        let meta = reg.get_model_metadata(ModelType::Recognition);
        assert_eq!(meta.name, "Text Recognition");
        assert_eq!(meta.version, "1.0.0");
        assert_eq!(meta.input_shape, vec![1, 1, 32, 128]);
    }

    #[tokio::test]
    async fn test_model_caching() {
        // Two loads of the same model type must yield the same Arc.
        let mut reg = ModelRegistry::new();
        let first = reg.load_detection_model().await.unwrap();
        let second = reg.load_detection_model().await.unwrap();
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[test]
    fn test_clear_cache() {
        let mut reg = ModelRegistry::new();
        reg.clear_cache();
        assert_eq!(reg.cache.len(), 0);
    }
}
|
||||
Reference in New Issue
Block a user