364 lines
11 KiB
Rust
364 lines
11 KiB
Rust
//! OCR Engine Implementation
|
|
//!
|
|
//! This module provides the main OcrEngine for orchestrating OCR operations.
|
|
//! It handles model loading, inference coordination, and result assembly.
|
|
|
|
use super::{
|
|
confidence::aggregate_confidence,
|
|
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
|
|
inference::{DetectionResult, InferenceEngine, RecognitionResult},
|
|
models::{ModelHandle, ModelRegistry},
|
|
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
|
|
};
|
|
use parking_lot::RwLock;
|
|
use std::sync::Arc;
|
|
use std::time::Instant;
|
|
use tracing::{debug, info, warn};
|
|
|
|
/// OCR processor trait for custom implementations
|
|
pub trait OcrProcessor: Send + Sync {
|
|
/// Process an image and return OCR results
|
|
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;
|
|
|
|
/// Batch process multiple images
|
|
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
|
|
}
|
|
|
|
/// Main OCR Engine with thread-safe model management
|
|
pub struct OcrEngine {
|
|
/// Model registry for loading and caching models
|
|
registry: Arc<RwLock<ModelRegistry>>,
|
|
|
|
/// Inference engine for running ONNX models
|
|
inference: Arc<InferenceEngine>,
|
|
|
|
/// Default OCR options
|
|
default_options: OcrOptions,
|
|
|
|
/// Vocabulary for decoding
|
|
vocabulary: Arc<Vocabulary>,
|
|
|
|
/// Whether the engine is warmed up
|
|
warmed_up: Arc<RwLock<bool>>,
|
|
}
|
|
|
|
impl OcrEngine {
|
|
/// Create a new OCR engine with default models
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```no_run
|
|
/// # use ruvector_scipix::ocr::OcrEngine;
|
|
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
|
/// let engine = OcrEngine::new().await?;
|
|
/// # Ok(())
|
|
/// # }
|
|
/// ```
|
|
pub async fn new() -> Result<Self> {
|
|
Self::with_options(OcrOptions::default()).await
|
|
}
|
|
|
|
/// Create a new OCR engine with custom options
|
|
pub async fn with_options(options: OcrOptions) -> Result<Self> {
|
|
info!("Initializing OCR engine with options: {:?}", options);
|
|
|
|
// Initialize model registry
|
|
let registry = Arc::new(RwLock::new(ModelRegistry::new()));
|
|
|
|
// Load default models (in production, these would be downloaded/cached)
|
|
debug!("Loading detection model...");
|
|
let detection_model = registry.write().load_detection_model().await.map_err(|e| {
|
|
OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
|
|
})?;
|
|
|
|
debug!("Loading recognition model...");
|
|
let recognition_model = registry
|
|
.write()
|
|
.load_recognition_model()
|
|
.await
|
|
.map_err(|e| {
|
|
OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
|
|
})?;
|
|
|
|
let math_model =
|
|
if options.enable_math {
|
|
debug!("Loading math recognition model...");
|
|
Some(registry.write().load_math_model().await.map_err(|e| {
|
|
OcrError::ModelLoading(format!("Failed to load math model: {}", e))
|
|
})?)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
// Create inference engine
|
|
let inference = Arc::new(InferenceEngine::new(
|
|
detection_model,
|
|
recognition_model,
|
|
math_model,
|
|
options.use_gpu,
|
|
)?);
|
|
|
|
// Load vocabulary
|
|
let vocabulary = Arc::new(Vocabulary::default());
|
|
|
|
let engine = Self {
|
|
registry,
|
|
inference,
|
|
default_options: options,
|
|
vocabulary,
|
|
warmed_up: Arc::new(RwLock::new(false)),
|
|
};
|
|
|
|
info!("OCR engine initialized successfully");
|
|
Ok(engine)
|
|
}
|
|
|
|
/// Warm up the engine by running a dummy inference
|
|
///
|
|
/// This helps reduce latency for the first real inference by initializing
|
|
/// all ONNX runtime resources.
|
|
pub async fn warmup(&self) -> Result<()> {
|
|
if *self.warmed_up.read() {
|
|
debug!("Engine already warmed up, skipping");
|
|
return Ok(());
|
|
}
|
|
|
|
info!("Warming up OCR engine...");
|
|
let start = Instant::now();
|
|
|
|
// Create a small dummy image (100x100 black image)
|
|
let dummy_image = vec![0u8; 100 * 100 * 3];
|
|
|
|
// Run a dummy inference
|
|
let _ = self.recognize(&dummy_image).await;
|
|
|
|
*self.warmed_up.write() = true;
|
|
info!("Engine warmup completed in {:?}", start.elapsed());
|
|
Ok(())
|
|
}
|
|
|
|
/// Recognize text in an image using default options
|
|
pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
|
|
self.recognize_with_options(image_data, &self.default_options)
|
|
.await
|
|
}
|
|
|
|
/// Recognize text in an image with custom options
|
|
pub async fn recognize_with_options(
|
|
&self,
|
|
image_data: &[u8],
|
|
options: &OcrOptions,
|
|
) -> Result<OcrResult> {
|
|
let start = Instant::now();
|
|
debug!("Starting OCR recognition");
|
|
|
|
// Step 1: Run text detection
|
|
debug!("Running text detection...");
|
|
let detection_results = self
|
|
.inference
|
|
.run_detection(image_data, options.detection_threshold)
|
|
.await?;
|
|
|
|
debug!("Detected {} regions", detection_results.len());
|
|
|
|
if detection_results.is_empty() {
|
|
warn!("No text regions detected");
|
|
return Ok(OcrResult {
|
|
text: String::new(),
|
|
confidence: 0.0,
|
|
regions: vec![],
|
|
has_math: false,
|
|
processing_time_ms: start.elapsed().as_millis() as u64,
|
|
});
|
|
}
|
|
|
|
// Step 2: Run recognition on each detected region
|
|
debug!("Running text recognition...");
|
|
let mut text_regions = Vec::new();
|
|
let mut has_math = false;
|
|
|
|
for detection in detection_results {
|
|
// Determine region type
|
|
let region_type = if options.enable_math && detection.is_math_likely {
|
|
has_math = true;
|
|
RegionType::Math
|
|
} else {
|
|
RegionType::Text
|
|
};
|
|
|
|
// Run appropriate recognition
|
|
let recognition = if region_type == RegionType::Math {
|
|
self.inference
|
|
.run_math_recognition(&detection.region_image, options)
|
|
.await?
|
|
} else {
|
|
self.inference
|
|
.run_recognition(&detection.region_image, options)
|
|
.await?
|
|
};
|
|
|
|
// Decode the recognition output
|
|
let decoded_text = self.decode_output(&recognition, options)?;
|
|
|
|
// Calculate confidence
|
|
let confidence = aggregate_confidence(&recognition.character_confidences);
|
|
|
|
// Filter by confidence threshold
|
|
if confidence < options.recognition_threshold {
|
|
debug!(
|
|
"Skipping region with low confidence: {:.2} < {:.2}",
|
|
confidence, options.recognition_threshold
|
|
);
|
|
continue;
|
|
}
|
|
|
|
// Build character list
|
|
let characters = decoded_text
|
|
.chars()
|
|
.zip(recognition.character_confidences.iter())
|
|
.map(|(ch, &conf)| Character {
|
|
char: ch,
|
|
confidence: conf,
|
|
bbox: None, // Could be populated if available from model
|
|
})
|
|
.collect();
|
|
|
|
text_regions.push(TextRegion {
|
|
bbox: detection.bbox,
|
|
text: decoded_text,
|
|
confidence,
|
|
region_type,
|
|
characters,
|
|
});
|
|
}
|
|
|
|
// Step 3: Combine results
|
|
let combined_text = text_regions
|
|
.iter()
|
|
.map(|r| r.text.as_str())
|
|
.collect::<Vec<_>>()
|
|
.join(" ");
|
|
|
|
let overall_confidence = if text_regions.is_empty() {
|
|
0.0
|
|
} else {
|
|
text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
|
|
};
|
|
|
|
let processing_time_ms = start.elapsed().as_millis() as u64;
|
|
|
|
debug!(
|
|
"OCR completed in {}ms, recognized {} regions",
|
|
processing_time_ms,
|
|
text_regions.len()
|
|
);
|
|
|
|
Ok(OcrResult {
|
|
text: combined_text,
|
|
confidence: overall_confidence,
|
|
regions: text_regions,
|
|
has_math,
|
|
processing_time_ms,
|
|
})
|
|
}
|
|
|
|
/// Batch process multiple images
|
|
pub async fn recognize_batch(
|
|
&self,
|
|
images: &[&[u8]],
|
|
options: &OcrOptions,
|
|
) -> Result<Vec<OcrResult>> {
|
|
info!("Processing batch of {} images", images.len());
|
|
let start = Instant::now();
|
|
|
|
// Process images in parallel using rayon
|
|
let results: Result<Vec<OcrResult>> = images
|
|
.iter()
|
|
.map(|image_data| {
|
|
// Note: In a real async implementation, we'd use tokio::spawn
|
|
// For now, we'll use blocking since we're in a sync context
|
|
futures::executor::block_on(self.recognize_with_options(image_data, options))
|
|
})
|
|
.collect();
|
|
|
|
info!("Batch processing completed in {:?}", start.elapsed());
|
|
|
|
results
|
|
}
|
|
|
|
/// Decode recognition output using the selected decoder
|
|
fn decode_output(
|
|
&self,
|
|
recognition: &RecognitionResult,
|
|
options: &OcrOptions,
|
|
) -> Result<String> {
|
|
debug!("Decoding output with {:?} decoder", options.decoder_type);
|
|
|
|
let decoded = match options.decoder_type {
|
|
DecoderType::BeamSearch => {
|
|
let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
|
|
decoder.decode(&recognition.logits)?
|
|
}
|
|
DecoderType::Greedy => {
|
|
let decoder = GreedyDecoder::new(self.vocabulary.clone());
|
|
decoder.decode(&recognition.logits)?
|
|
}
|
|
DecoderType::CTC => {
|
|
let decoder = CTCDecoder::new(self.vocabulary.clone());
|
|
decoder.decode(&recognition.logits)?
|
|
}
|
|
};
|
|
|
|
Ok(decoded)
|
|
}
|
|
|
|
/// Get the current model registry
|
|
pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
|
|
Arc::clone(&self.registry)
|
|
}
|
|
|
|
/// Get the default options
|
|
pub fn default_options(&self) -> &OcrOptions {
|
|
&self.default_options
|
|
}
|
|
|
|
/// Check if engine is warmed up
|
|
pub fn is_warmed_up(&self) -> bool {
|
|
*self.warmed_up.read()
|
|
}
|
|
}
|
|
|
|
impl OcrProcessor for OcrEngine {
|
|
fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
|
|
// Blocking wrapper for async method
|
|
futures::executor::block_on(self.recognize_with_options(image_data, options))
|
|
}
|
|
|
|
fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
|
|
// Blocking wrapper for async method
|
|
futures::executor::block_on(self.recognize_batch(images, options))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Overriding `decoder_type` on top of the defaults is reflected in the options.
    #[test]
    fn test_decoder_selection() {
        let options = OcrOptions {
            decoder_type: DecoderType::BeamSearch,
            ..OcrOptions::default()
        };

        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
    }

    /// A shared `RwLock<bool>` starts out false and reads back true after a write.
    #[test]
    fn test_warmup_flag() {
        let flag = Arc::new(RwLock::new(false));

        let initially_set = *flag.read();
        assert!(!initially_set);

        {
            let mut guard = flag.write();
            *guard = true;
        }

        let set_after_write = *flag.read();
        assert!(set_after_write);
    }
}
|