Files
wifi-densepose/vendor/ruvector/examples/scipix/src/ocr/engine.rs

364 lines
11 KiB
Rust

//! OCR Engine Implementation
//!
//! This module provides the main OcrEngine for orchestrating OCR operations.
//! It handles model loading, inference coordination, and result assembly.
use super::{
confidence::aggregate_confidence,
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
inference::{DetectionResult, InferenceEngine, RecognitionResult},
models::{ModelHandle, ModelRegistry},
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, warn};
/// OCR processor trait for custom implementations
pub trait OcrProcessor: Send + Sync {
/// Process an image and return OCR results
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;
/// Batch process multiple images
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
/// Main OCR Engine with thread-safe model management
pub struct OcrEngine {
/// Model registry for loading and caching models
registry: Arc<RwLock<ModelRegistry>>,
/// Inference engine for running ONNX models
inference: Arc<InferenceEngine>,
/// Default OCR options
default_options: OcrOptions,
/// Vocabulary for decoding
vocabulary: Arc<Vocabulary>,
/// Whether the engine is warmed up
warmed_up: Arc<RwLock<bool>>,
}
impl OcrEngine {
/// Create a new OCR engine with default models
///
/// # Example
///
/// ```no_run
/// # use ruvector_scipix::ocr::OcrEngine;
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let engine = OcrEngine::new().await?;
/// # Ok(())
/// # }
/// ```
pub async fn new() -> Result<Self> {
Self::with_options(OcrOptions::default()).await
}
/// Create a new OCR engine with custom options
pub async fn with_options(options: OcrOptions) -> Result<Self> {
info!("Initializing OCR engine with options: {:?}", options);
// Initialize model registry
let registry = Arc::new(RwLock::new(ModelRegistry::new()));
// Load default models (in production, these would be downloaded/cached)
debug!("Loading detection model...");
let detection_model = registry.write().load_detection_model().await.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
})?;
debug!("Loading recognition model...");
let recognition_model = registry
.write()
.load_recognition_model()
.await
.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
})?;
let math_model =
if options.enable_math {
debug!("Loading math recognition model...");
Some(registry.write().load_math_model().await.map_err(|e| {
OcrError::ModelLoading(format!("Failed to load math model: {}", e))
})?)
} else {
None
};
// Create inference engine
let inference = Arc::new(InferenceEngine::new(
detection_model,
recognition_model,
math_model,
options.use_gpu,
)?);
// Load vocabulary
let vocabulary = Arc::new(Vocabulary::default());
let engine = Self {
registry,
inference,
default_options: options,
vocabulary,
warmed_up: Arc::new(RwLock::new(false)),
};
info!("OCR engine initialized successfully");
Ok(engine)
}
/// Warm up the engine by running a dummy inference
///
/// This helps reduce latency for the first real inference by initializing
/// all ONNX runtime resources.
pub async fn warmup(&self) -> Result<()> {
if *self.warmed_up.read() {
debug!("Engine already warmed up, skipping");
return Ok(());
}
info!("Warming up OCR engine...");
let start = Instant::now();
// Create a small dummy image (100x100 black image)
let dummy_image = vec![0u8; 100 * 100 * 3];
// Run a dummy inference
let _ = self.recognize(&dummy_image).await;
*self.warmed_up.write() = true;
info!("Engine warmup completed in {:?}", start.elapsed());
Ok(())
}
/// Recognize text in an image using default options
pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
self.recognize_with_options(image_data, &self.default_options)
.await
}
/// Recognize text in an image with custom options
pub async fn recognize_with_options(
&self,
image_data: &[u8],
options: &OcrOptions,
) -> Result<OcrResult> {
let start = Instant::now();
debug!("Starting OCR recognition");
// Step 1: Run text detection
debug!("Running text detection...");
let detection_results = self
.inference
.run_detection(image_data, options.detection_threshold)
.await?;
debug!("Detected {} regions", detection_results.len());
if detection_results.is_empty() {
warn!("No text regions detected");
return Ok(OcrResult {
text: String::new(),
confidence: 0.0,
regions: vec![],
has_math: false,
processing_time_ms: start.elapsed().as_millis() as u64,
});
}
// Step 2: Run recognition on each detected region
debug!("Running text recognition...");
let mut text_regions = Vec::new();
let mut has_math = false;
for detection in detection_results {
// Determine region type
let region_type = if options.enable_math && detection.is_math_likely {
has_math = true;
RegionType::Math
} else {
RegionType::Text
};
// Run appropriate recognition
let recognition = if region_type == RegionType::Math {
self.inference
.run_math_recognition(&detection.region_image, options)
.await?
} else {
self.inference
.run_recognition(&detection.region_image, options)
.await?
};
// Decode the recognition output
let decoded_text = self.decode_output(&recognition, options)?;
// Calculate confidence
let confidence = aggregate_confidence(&recognition.character_confidences);
// Filter by confidence threshold
if confidence < options.recognition_threshold {
debug!(
"Skipping region with low confidence: {:.2} < {:.2}",
confidence, options.recognition_threshold
);
continue;
}
// Build character list
let characters = decoded_text
.chars()
.zip(recognition.character_confidences.iter())
.map(|(ch, &conf)| Character {
char: ch,
confidence: conf,
bbox: None, // Could be populated if available from model
})
.collect();
text_regions.push(TextRegion {
bbox: detection.bbox,
text: decoded_text,
confidence,
region_type,
characters,
});
}
// Step 3: Combine results
let combined_text = text_regions
.iter()
.map(|r| r.text.as_str())
.collect::<Vec<_>>()
.join(" ");
let overall_confidence = if text_regions.is_empty() {
0.0
} else {
text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
};
let processing_time_ms = start.elapsed().as_millis() as u64;
debug!(
"OCR completed in {}ms, recognized {} regions",
processing_time_ms,
text_regions.len()
);
Ok(OcrResult {
text: combined_text,
confidence: overall_confidence,
regions: text_regions,
has_math,
processing_time_ms,
})
}
/// Batch process multiple images
pub async fn recognize_batch(
&self,
images: &[&[u8]],
options: &OcrOptions,
) -> Result<Vec<OcrResult>> {
info!("Processing batch of {} images", images.len());
let start = Instant::now();
// Process images in parallel using rayon
let results: Result<Vec<OcrResult>> = images
.iter()
.map(|image_data| {
// Note: In a real async implementation, we'd use tokio::spawn
// For now, we'll use blocking since we're in a sync context
futures::executor::block_on(self.recognize_with_options(image_data, options))
})
.collect();
info!("Batch processing completed in {:?}", start.elapsed());
results
}
/// Decode recognition output using the selected decoder
fn decode_output(
&self,
recognition: &RecognitionResult,
options: &OcrOptions,
) -> Result<String> {
debug!("Decoding output with {:?} decoder", options.decoder_type);
let decoded = match options.decoder_type {
DecoderType::BeamSearch => {
let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
decoder.decode(&recognition.logits)?
}
DecoderType::Greedy => {
let decoder = GreedyDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
DecoderType::CTC => {
let decoder = CTCDecoder::new(self.vocabulary.clone());
decoder.decode(&recognition.logits)?
}
};
Ok(decoded)
}
/// Get the current model registry
pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
Arc::clone(&self.registry)
}
/// Get the default options
pub fn default_options(&self) -> &OcrOptions {
&self.default_options
}
/// Check if engine is warmed up
pub fn is_warmed_up(&self) -> bool {
*self.warmed_up.read()
}
}
impl OcrProcessor for OcrEngine {
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
// Blocking wrapper for async method
futures::executor::block_on(self.recognize_with_options(image_data, options))
}
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
// Blocking wrapper for async method
futures::executor::block_on(self.recognize_batch(images, options))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_decoder_selection() {
let options = OcrOptions {
decoder_type: DecoderType::BeamSearch,
..Default::default()
};
assert_eq!(options.decoder_type, DecoderType::BeamSearch);
}
#[test]
fn test_warmup_flag() {
let flag = Arc::new(RwLock::new(false));
assert!(!*flag.read());
*flag.write() = true;
assert!(*flag.read());
}
}