Squashed 'vendor/ruvector/' content from commit b64c2172

git-subtree-dir: vendor/ruvector
git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
commit d803bfe2b1
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,308 @@
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::{sse::Event, IntoResponse, Sse},
Json,
};
use futures::stream::{self, Stream};
use serde::{Deserialize, Serialize};
use std::{convert::Infallible, time::Duration};
use tracing::{error, info, warn};
use validator::Validate;
use super::{
jobs::{JobStatus, PdfJob},
requests::{LatexRequest, PdfRequest, StrokesRequest, TextRequest},
responses::{ErrorResponse, PdfResponse, TextResponse},
state::AppState,
};
/// Health check handler
pub async fn get_health() -> impl IntoResponse {
#[derive(Serialize)]
struct Health {
status: &'static str,
version: &'static str,
}
Json(Health {
status: "ok",
version: env!("CARGO_PKG_VERSION"),
})
}
/// Process text/image OCR request
/// Supports multipart/form-data, base64, and URL inputs
///
/// # Important
/// This endpoint requires OCR models to be configured. If models are not available,
/// returns a 503 Service Unavailable error with instructions.
pub async fn process_text(
    State(_state): State<AppState>,
    Json(request): Json<TextRequest>,
) -> Result<Json<TextResponse>, ErrorResponse> {
    info!("Processing text OCR request");

    // Reject structurally invalid requests up front.
    if let Err(e) = request.validate() {
        warn!("Invalid request: {:?}", e);
        return Err(ErrorResponse::validation_error(format!("Validation failed: {}", e)));
    }

    // Resolve the image bytes (base64 decode or URL download).
    let image_data = request.get_image_data().await.map_err(|e| {
        error!("Failed to get image data: {:?}", e);
        ErrorResponse::internal_error("Failed to process image")
    })?;

    // An empty payload can never be OCR'd.
    if image_data.is_empty() {
        return Err(ErrorResponse::validation_error("Image data is empty"));
    }

    // OCR processing requires models to be configured; until then this
    // endpoint deliberately answers 503 with setup instructions.
    Err(ErrorResponse::service_unavailable(
        "OCR service not fully configured. ONNX models are required for OCR processing. \
        Please download compatible models (PaddleOCR, TrOCR) and configure the model directory. \
        See documentation at /docs/MODEL_SETUP.md for setup instructions.",
    ))
}
/// Process digital ink strokes
///
/// # Important
/// This endpoint requires OCR models to be configured.
pub async fn process_strokes(
    State(_state): State<AppState>,
    Json(request): Json<StrokesRequest>,
) -> Result<Json<TextResponse>, ErrorResponse> {
    info!(
        "Processing strokes request with {} strokes",
        request.strokes.len()
    );

    // Structural validation (includes the min-length check on `strokes`).
    if let Err(e) = request.validate() {
        return Err(ErrorResponse::validation_error(format!("Validation failed: {}", e)));
    }

    // Defensive guard: no stroke data means nothing to recognize.
    if request.strokes.is_empty() {
        return Err(ErrorResponse::validation_error("No strokes provided"));
    }

    // Stroke recognition requires models to be configured.
    Err(ErrorResponse::service_unavailable(
        "Stroke recognition service not configured. ONNX models required for ink recognition.",
    ))
}
/// Process legacy LaTeX equation request
///
/// # Important
/// This endpoint requires OCR models to be configured.
pub async fn process_latex(
    State(_state): State<AppState>,
    Json(request): Json<LatexRequest>,
) -> Result<Json<TextResponse>, ErrorResponse> {
    info!("Processing legacy LaTeX request");

    // Validate the request shape before doing anything else.
    if let Err(e) = request.validate() {
        return Err(ErrorResponse::validation_error(format!("Validation failed: {}", e)));
    }

    // LaTeX recognition requires models to be configured.
    Err(ErrorResponse::service_unavailable(
        "LaTeX recognition service not configured. ONNX models required.",
    ))
}
/// Create async PDF processing job
///
/// Validates the request, creates a job, and enqueues it for background
/// processing. Returns the job id so the client can poll `/v3/pdf/:id`.
///
/// # Fix
/// The response now reports `JobStatus::Queued` instead of a hard-coded
/// `Processing`: `enqueue` stores the job as `Queued`, and the background
/// worker is what flips it to `Processing`. Reporting `Processing` here
/// contradicted what an immediate `get_pdf_status` call would return.
pub async fn process_pdf(
    State(state): State<AppState>,
    Json(request): Json<PdfRequest>,
) -> Result<Json<PdfResponse>, ErrorResponse> {
    info!("Creating PDF processing job");
    request
        .validate()
        .map_err(|e| ErrorResponse::validation_error(format!("Validation failed: {}", e)))?;

    // Create job
    let job = PdfJob::new(request);
    let job_id = job.id.clone();

    // Queue job for the background worker.
    state.job_queue.enqueue(job).await.map_err(|e| {
        error!("Failed to enqueue job: {:?}", e);
        ErrorResponse::internal_error("Failed to create PDF job")
    })?;

    let response = PdfResponse {
        pdf_id: job_id,
        status: JobStatus::Queued,
        message: Some("PDF processing queued".to_string()),
        result: None,
        error: None,
    };
    Ok(Json(response))
}
/// Get PDF job status
///
/// Looks up the job by id and returns its current status plus any recorded
/// result or error. Responds 404 when the id is unknown.
///
/// NOTE(review): status, result, and error come from three separate queue
/// lookups, so they are not an atomic snapshot — the job may transition
/// between calls. Confirm whether a combined accessor is warranted.
pub async fn get_pdf_status(
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> Result<Json<PdfResponse>, ErrorResponse> {
    info!("Getting PDF job status: {}", id);
    let status = state
        .job_queue
        .get_status(&id)
        .await
        .ok_or_else(|| ErrorResponse::not_found("Job not found"))?;
    let response = PdfResponse {
        pdf_id: id.clone(),
        status: status.clone(),
        message: Some(format!("Job status: {:?}", status)),
        result: state.job_queue.get_result(&id).await,
        error: state.job_queue.get_error(&id).await,
    };
    Ok(Json(response))
}
/// Delete (cancel) a PDF job by id.
///
/// Returns 204 No Content on success, 404 when the job does not exist.
pub async fn delete_pdf_job(
    State(state): State<AppState>,
    Path(id): Path<String>,
) -> Result<StatusCode, ErrorResponse> {
    info!("Deleting PDF job: {}", id);
    match state.job_queue.cancel(&id).await {
        Ok(()) => Ok(StatusCode::NO_CONTENT),
        Err(_) => Err(ErrorResponse::not_found("Job not found")),
    }
}
/// Stream PDF processing results via SSE
///
/// NOTE(review): this is placeholder/demo behavior — `_state` and `_id` are
/// not used to look up real job output; the stream emits fabricated page
/// events (pages 0..=10, one every 500 ms) and then closes. Wire this to the
/// real job queue before relying on it.
pub async fn stream_pdf_results(
    State(_state): State<AppState>,
    Path(_id): Path<String>,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
    info!("Streaming PDF results for job: {}", _id);
    // `unfold` carries the page counter as stream state.
    let stream = stream::unfold(0, move |page| {
        async move {
            if page > 10 {
                // Example: stop after 10 pages
                return None;
            }
            tokio::time::sleep(Duration::from_millis(500)).await;
            // `.ok()?` ends the stream if JSON serialization ever fails.
            let event = Event::default()
                .json_data(serde_json::json!({
                    "page": page,
                    "text": format!("Content from page {}", page),
                    "progress": (page as f32 / 10.0) * 100.0
                }))
                .ok()?;
            Some((Ok(event), page + 1))
        }
    });
    Sse::new(stream)
}
/// Convert document to different format (MMD/DOCX/etc)
///
/// # Note
/// Document conversion requires additional backend services to be configured.
/// The request body is accepted as arbitrary JSON but is currently ignored;
/// the endpoint always answers 501 Not Implemented.
pub async fn convert_document(
    State(_state): State<AppState>,
    Json(_request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, ErrorResponse> {
    info!("Converting document");
    // Document conversion is not yet implemented
    Err(ErrorResponse::not_implemented(
        "Document conversion is not yet implemented. This feature requires additional backend services."
    ))
}
/// Query parameters accepted by the OCR history endpoint.
#[derive(Deserialize)]
pub struct HistoryQuery {
    // Zero-based page index; defaults to 0 when absent.
    #[serde(default)]
    page: u32,
    // Page size; defaults to `default_limit()` (50) when absent.
    #[serde(default = "default_limit")]
    limit: u32,
}
/// Serde default page size for [`HistoryQuery::limit`].
fn default_limit() -> u32 {
    const DEFAULT_HISTORY_LIMIT: u32 = 50;
    DEFAULT_HISTORY_LIMIT
}
/// Get OCR processing history
///
/// # Note
/// History storage requires a database backend to be configured.
/// Returns empty results if no database is available; the pagination
/// parameters are echoed back so clients can keep their paging logic.
pub async fn get_ocr_results(
    State(_state): State<AppState>,
    Query(params): Query<HistoryQuery>,
) -> Result<Json<serde_json::Value>, ErrorResponse> {
    info!(
        "Getting OCR results history: page={}, limit={}",
        params.page, params.limit
    );
    // History storage not configured - return empty results with notice
    Ok(Json(serde_json::json!({
        "results": [],
        "total": 0,
        "page": params.page,
        "limit": params.limit,
        "notice": "History storage not configured. Results are not persisted."
    })))
}
/// Get OCR usage statistics
///
/// # Note
/// Usage tracking requires a database backend to be configured.
/// Returns zeros (and null quotas) if no database is available, with an
/// explanatory `notice` field so clients can surface the limitation.
pub async fn get_ocr_usage(
    State(_state): State<AppState>,
) -> Result<Json<serde_json::Value>, ErrorResponse> {
    info!("Getting OCR usage statistics");
    // Usage tracking not configured - return zeros with notice
    Ok(Json(serde_json::json!({
        "requests_today": 0,
        "requests_month": 0,
        "quota_limit": null,
        "quota_remaining": null,
        "notice": "Usage tracking not configured. Statistics are not recorded."
    })))
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: the health endpoint must always answer 200 OK.
    #[tokio::test]
    async fn test_health_check() {
        let response = get_health().await.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }
}

View File

@@ -0,0 +1,281 @@
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock};
use uuid::Uuid;
use super::requests::PdfRequest;
/// Job status enumeration
///
/// Serialized in lowercase (`"queued"`, `"processing"`, ...) for API payloads.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JobStatus {
    /// Job is queued but not started
    Queued,
    /// Job is currently processing
    Processing,
    /// Job completed successfully
    Completed,
    /// Job failed with error
    Failed,
    /// Job was cancelled
    Cancelled,
}
/// PDF processing job
///
/// A queued unit of work: the original request plus bookkeeping fields that
/// the background worker mutates as the job progresses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfJob {
    /// Unique job identifier (UUID v4 string)
    pub id: String,
    /// Original request
    pub request: PdfRequest,
    /// Current status
    pub status: JobStatus,
    /// Creation timestamp
    pub created_at: DateTime<Utc>,
    /// Last update timestamp (refreshed on every status transition)
    pub updated_at: DateTime<Utc>,
    /// Processing result (set when status becomes Completed)
    pub result: Option<String>,
    /// Error message (if failed)
    pub error: Option<String>,
}
impl PdfJob {
    /// Build a freshly queued job wrapping `request`, with a random UUID id
    /// and matching creation/update timestamps.
    pub fn new(request: PdfRequest) -> Self {
        let created = Utc::now();
        Self {
            id: Uuid::new_v4().to_string(),
            request,
            status: JobStatus::Queued,
            created_at: created,
            updated_at: created,
            result: None,
            error: None,
        }
    }

    /// Transition to `status`, refreshing `updated_at`.
    pub fn update_status(&mut self, status: JobStatus) {
        self.status = status;
        self.updated_at = Utc::now();
    }

    /// Record a successful result and mark the job `Completed`.
    pub fn set_result(&mut self, result: String) {
        self.result = Some(result);
        self.update_status(JobStatus::Completed);
    }

    /// Record a failure message and mark the job `Failed`.
    pub fn set_error(&mut self, error: String) {
        self.error = Some(error);
        self.update_status(JobStatus::Failed);
    }
}
/// Async job queue with webhook support
pub struct JobQueue {
    /// Job storage, shared with the background worker task
    jobs: Arc<RwLock<HashMap<String, PdfJob>>>,
    /// Job submission channel (bounded; see `with_capacity`)
    tx: mpsc::Sender<PdfJob>,
    /// Job processing handle — held only to keep the spawned worker owned by
    /// the queue; it is never awaited or aborted explicitly
    _handle: Option<tokio::task::JoinHandle<()>>,
}
impl JobQueue {
    /// Create a new job queue with the default capacity (1000 queued jobs).
    pub fn new() -> Self {
        Self::with_capacity(1000)
    }

    /// Create a job queue with specific channel capacity and spawn the
    /// background worker that drains it.
    pub fn with_capacity(capacity: usize) -> Self {
        let jobs = Arc::new(RwLock::new(HashMap::new()));
        let (tx, rx) = mpsc::channel(capacity);
        let queue_jobs = jobs.clone();
        let handle = tokio::spawn(async move {
            Self::process_jobs(queue_jobs, rx).await;
        });
        Self {
            jobs,
            tx,
            _handle: Some(handle),
        }
    }

    /// Enqueue a new job: store it as `Queued`, then hand it to the worker.
    ///
    /// # Errors
    /// Fails if the submission channel is closed (worker task gone).
    pub async fn enqueue(&self, mut job: PdfJob) -> anyhow::Result<()> {
        job.update_status(JobStatus::Queued);
        // Store job
        {
            let mut jobs = self.jobs.write().await;
            jobs.insert(job.id.clone(), job.clone());
        }
        // Send to processing queue
        self.tx.send(job).await?;
        Ok(())
    }

    /// Get job status, or `None` for an unknown id.
    pub async fn get_status(&self, id: &str) -> Option<JobStatus> {
        let jobs = self.jobs.read().await;
        jobs.get(id).map(|job| job.status.clone())
    }

    /// Get job result (set once the job completes).
    pub async fn get_result(&self, id: &str) -> Option<String> {
        let jobs = self.jobs.read().await;
        jobs.get(id).and_then(|job| job.result.clone())
    }

    /// Get job error message (set if the job failed).
    pub async fn get_error(&self, id: &str) -> Option<String> {
        let jobs = self.jobs.read().await;
        jobs.get(id).and_then(|job| job.error.clone())
    }

    /// Cancel a job by marking it `Cancelled`.
    ///
    /// # Errors
    /// Fails when the id is unknown.
    pub async fn cancel(&self, id: &str) -> anyhow::Result<()> {
        let mut jobs = self.jobs.write().await;
        if let Some(job) = jobs.get_mut(id) {
            job.update_status(JobStatus::Cancelled);
            Ok(())
        } else {
            anyhow::bail!("Job not found")
        }
    }

    /// Background job processor.
    ///
    /// # Fixes
    /// - Jobs cancelled while queued are skipped, and a cancellation that
    ///   arrives during processing is no longer overwritten with `Completed`
    ///   (the old code unconditionally called `set_result`).
    /// - The webhook HTTP call is made AFTER releasing the jobs write lock,
    ///   so a slow webhook endpoint no longer blocks every other queue
    ///   operation for the duration of the request.
    async fn process_jobs(
        jobs: Arc<RwLock<HashMap<String, PdfJob>>>,
        mut rx: mpsc::Receiver<PdfJob>,
    ) {
        while let Some(job) = rx.recv().await {
            let job_id = job.id.clone();
            // Mark as processing — unless the job vanished or was cancelled
            // while sitting in the queue.
            {
                let mut jobs_lock = jobs.write().await;
                match jobs_lock.get_mut(&job_id) {
                    Some(stored_job) if stored_job.status != JobStatus::Cancelled => {
                        stored_job.update_status(JobStatus::Processing);
                    }
                    _ => continue,
                }
            }
            // Simulate PDF processing
            tokio::time::sleep(std::time::Duration::from_secs(2)).await;
            // Record the result and capture webhook data, then drop the lock
            // before doing any network I/O.
            let webhook = {
                let mut jobs_lock = jobs.write().await;
                match jobs_lock.get_mut(&job_id) {
                    Some(stored_job) if stored_job.status != JobStatus::Cancelled => {
                        stored_job.set_result("Processed PDF content".to_string());
                        stored_job
                            .request
                            .webhook_url
                            .clone()
                            .map(|url| (url, stored_job.clone()))
                    }
                    _ => None,
                }
            };
            if let Some((url, snapshot)) = webhook {
                Self::send_webhook(&url, &snapshot).await;
            }
        }
    }

    /// Send webhook notification with the job's final state; failures are
    /// logged and otherwise ignored (webhooks are best-effort).
    async fn send_webhook(url: &str, job: &PdfJob) {
        let client = reqwest::Client::new();
        let payload = serde_json::json!({
            "job_id": job.id,
            "status": job.status,
            "result": job.result,
            "error": job.error,
        });
        if let Err(e) = client.post(url).json(&payload).send().await {
            tracing::error!("Failed to send webhook to {}: {:?}", url, e);
        }
    }
}
// Default queue uses the standard 1000-job capacity (see `JobQueue::new`).
impl Default for JobQueue {
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::requests::{PdfOptions, RequestMetadata};

    // A freshly constructed job starts Queued with no result/error.
    #[tokio::test]
    async fn test_job_creation() {
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };
        let job = PdfJob::new(request);
        assert_eq!(job.status, JobStatus::Queued);
        assert!(job.result.is_none());
        assert!(job.error.is_none());
    }

    // Enqueued jobs are immediately visible via get_status.
    #[tokio::test]
    async fn test_job_queue_enqueue() {
        let queue = JobQueue::new();
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };
        let job = PdfJob::new(request);
        let job_id = job.id.clone();
        queue.enqueue(job).await.unwrap();
        let status = queue.get_status(&job_id).await;
        assert!(status.is_some());
    }

    // Cancelling right after enqueue reports Cancelled.
    // NOTE(review): this races with the background worker (which sets
    // Processing before sleeping); the assertion holds because cancel()
    // runs before the worker's 2s "processing" completes.
    #[tokio::test]
    async fn test_job_cancellation() {
        let queue = JobQueue::new();
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };
        let job = PdfJob::new(request);
        let job_id = job.id.clone();
        queue.enqueue(job).await.unwrap();
        queue.cancel(&job_id).await.unwrap();
        let status = queue.get_status(&job_id).await;
        assert_eq!(status, Some(JobStatus::Cancelled));
    }
}

View File

@@ -0,0 +1,197 @@
use axum::{
extract::{Request, State},
http::HeaderMap,
middleware::Next,
response::Response,
};
use governor::{
clock::DefaultClock,
state::{InMemoryState, NotKeyed},
Quota, RateLimiter,
};
use nonzero_ext::nonzero;
use sha2::{Digest, Sha256};
use std::sync::Arc;
use tracing::{debug, warn};
use super::{responses::ErrorResponse, state::AppState};
/// Authentication middleware
/// Validates app_id and app_key from headers or query parameters
///
/// # Fix
/// Credentials are extracted as owned `String`s. The previous code kept
/// `&str` slices borrowed from `request` (via the query-string fallback)
/// alive across `next.run(request)`, which moves `request` — a borrow/move
/// conflict the compiler rejects.
pub async fn auth_middleware(
    State(state): State<AppState>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, ErrorResponse> {
    // Check if authentication is enabled
    if !state.auth_enabled {
        debug!("Authentication disabled, allowing request");
        return Ok(next.run(request).await);
    }

    // Copy credentials out as owned Strings so the borrow of `request`
    // ends before `request` is moved into `next.run` below.
    let query = request.uri().query().unwrap_or("");
    let credential = |name: &str| -> Option<String> {
        headers
            .get(name)
            .and_then(|v| v.to_str().ok())
            .map(str::to_owned)
            // Fallback to query parameters
            .or_else(|| extract_query_param(query, name).map(str::to_owned))
    };
    let app_id = credential("app_id");
    let app_key = credential("app_key");

    // Validate credentials
    match (app_id, app_key) {
        (Some(id), Some(key)) => {
            if validate_credentials(&state, &id, &key).await {
                debug!("Authentication successful for app_id: {}", id);
                Ok(next.run(request).await)
            } else {
                warn!("Invalid credentials for app_id: {}", id);
                Err(ErrorResponse::unauthorized("Invalid credentials"))
            }
        }
        _ => {
            warn!("Missing authentication credentials");
            Err(ErrorResponse::unauthorized("Missing app_id or app_key"))
        }
    }
}
/// Rate limiting middleware using token bucket algorithm
///
/// Consults the shared limiter; requests over quota get a 429 response.
pub async fn rate_limit_middleware(
    State(state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, ErrorResponse> {
    // A failed check means the bucket is empty right now.
    if state.rate_limiter.check().is_err() {
        warn!("Rate limit exceeded");
        return Err(ErrorResponse::rate_limited(
            "Rate limit exceeded. Please try again later.",
        ));
    }
    debug!("Rate limit check passed");
    Ok(next.run(request).await)
}
/// Validate app credentials using secure comparison
///
/// SECURITY: This implementation:
/// 1. Requires credentials to be pre-configured in AppState
/// 2. Uses constant-time comparison to prevent timing attacks
/// 3. Hashes the key before comparison
///
/// Returns `true` only when `app_id` is known and the SHA-256 hex digest of
/// `app_key` matches the stored hash.
async fn validate_credentials(state: &AppState, app_id: &str, app_key: &str) -> bool {
    // Reject empty credentials
    if app_id.is_empty() || app_key.is_empty() {
        return false;
    }
    // Get configured credentials from state
    let Some(expected_key_hash) = state.api_keys.get(app_id) else {
        warn!("Unknown app_id attempted authentication: {}", app_id);
        return false;
    };
    // Hash the provided key
    let provided_key_hash = hash_api_key(app_key);
    // Constant-time comparison to prevent timing attacks
    // (both sides are fixed-length hex digests, so the length check inside
    // constant_time_compare does not leak key information)
    constant_time_compare(&provided_key_hash, expected_key_hash.as_str())
}
/// Hash an API key using SHA-256, returned as a lowercase hex digest.
fn hash_api_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    format!("{:x}", digest)
}
/// Constant-time string comparison to prevent timing attacks.
///
/// Equal-length inputs are XOR-folded byte by byte so the comparison time
/// does not depend on where the first mismatch occurs. A length mismatch
/// returns early — acceptable here because digest lengths are public.
fn constant_time_compare(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let diff = a
        .bytes()
        .zip(b.bytes())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
/// Extract the value of `param` from a raw `k=v&k=v` query string.
///
/// Returns the first match. A key without an `=value` part never matches,
/// and only the segment up to the next `=` is returned (matching the
/// original `split('=')` behavior).
fn extract_query_param<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    query.split('&').find_map(|pair| {
        let mut parts = pair.split('=');
        let key = parts.next()?;
        let value = parts.next()?;
        (key == param).then_some(value)
    })
}
/// Create a rate limiter with token bucket algorithm
///
/// `NotKeyed` means this is a single process-wide bucket shared by ALL
/// clients, not a per-client limit.
pub fn create_rate_limiter() -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
    // Allow 100 requests per minute
    let quota = Quota::per_minute(nonzero!(100u32));
    Arc::new(RateLimiter::direct(quota))
}
/// Type alias for the shared, process-wide rate limiter stored in AppState
pub type AppRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
#[cfg(test)]
mod tests {
    use super::*;

    // Query parsing: present keys resolve, absent keys return None.
    #[test]
    fn test_extract_query_param() {
        let query = "app_id=123&app_key=secret&foo=bar";
        assert_eq!(extract_query_param(query, "app_id"), Some("123"));
        assert_eq!(extract_query_param(query, "app_key"), Some("secret"));
        assert_eq!(extract_query_param(query, "foo"), Some("bar"));
        assert_eq!(extract_query_param(query, "missing"), None);
    }

    // Hashing is deterministic and input-sensitive.
    #[test]
    fn test_hash_api_key() {
        let key = "test_key_123";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);
        assert_eq!(hash1, hash2);
        assert_ne!(hash_api_key("different"), hash1);
    }

    // Equality, inequality, and length-mismatch cases.
    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc", "abc"));
        assert!(!constant_time_compare("abc", "abd"));
        assert!(!constant_time_compare("abc", "ab"));
        assert!(!constant_time_compare("", "a"));
    }

    // Empty app_id or app_key must always be rejected.
    #[tokio::test]
    async fn test_validate_credentials_rejects_empty() {
        let state = AppState::new();
        assert!(!validate_credentials(&state, "", "key").await);
        assert!(!validate_credentials(&state, "test", "").await);
        assert!(!validate_credentials(&state, "", "").await);
    }
}

View File

@@ -0,0 +1,91 @@
pub mod handlers;
pub mod jobs;
pub mod middleware;
pub mod requests;
pub mod responses;
pub mod routes;
pub mod state;
use anyhow::Result;
use axum::Router;
use std::net::SocketAddr;
use tokio::signal;
use tracing::{info, warn};
use self::state::AppState;
/// Main API server structure
///
/// Owns the shared application state and the socket address to bind.
pub struct ApiServer {
    // Shared state cloned into the router and its middleware
    state: AppState,
    // Address the TCP listener binds to in `start`
    addr: SocketAddr,
}
impl ApiServer {
    /// Create a new API server instance
    pub fn new(state: AppState, addr: SocketAddr) -> Self {
        Self { state, addr }
    }

    /// Start the API server with graceful shutdown
    ///
    /// Binds the listener, serves until a shutdown signal (Ctrl+C or, on
    /// Unix, SIGTERM) arrives, then returns. Consumes `self`.
    ///
    /// # Errors
    /// Returns an error if binding the address or serving fails.
    pub async fn start(self) -> Result<()> {
        let app = self.create_router();
        info!("Starting Scipix API server on {}", self.addr);
        let listener = tokio::net::TcpListener::bind(self.addr).await?;
        axum::serve(listener, app)
            .with_graceful_shutdown(shutdown_signal())
            .await?;
        info!("Server shutdown complete");
        Ok(())
    }

    /// Create the application router with all routes and middleware
    fn create_router(&self) -> Router {
        routes::router(self.state.clone())
    }
}
/// Graceful shutdown signal handler
///
/// Resolves when either Ctrl+C or (on Unix) SIGTERM is received; the
/// non-Unix branch substitutes a never-resolving future so `select!`
/// compiles on all platforms.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler")
    };
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = ctrl_c => {
            warn!("Received Ctrl+C, shutting down...");
        },
        _ = terminate => {
            warn!("Received terminate signal, shutting down...");
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Construction stores the bind address without touching the network.
    #[tokio::test]
    async fn test_server_creation() {
        let state = AppState::new();
        let addr = "127.0.0.1:3000".parse().unwrap();
        let server = ApiServer::new(state, addr);
        assert_eq!(server.addr, addr);
    }
}

View File

@@ -0,0 +1,227 @@
use serde::{Deserialize, Serialize};
use validator::Validate;
/// Text/Image OCR request
///
/// Exactly one of the image sources is expected; `get_image_data` resolves
/// `base64` first, then `url` (`src` is currently unused by that method).
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct TextRequest {
    /// Image source (base64, URL, or multipart)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub src: Option<String>,
    /// Base64 encoded image data
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base64: Option<String>,
    /// Image URL (validated as a well-formed URL)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub url: Option<String>,
    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
impl TextRequest {
    /// Resolve the request's image bytes.
    ///
    /// Precedence: inline `base64` first, then downloading `url`.
    ///
    /// # Errors
    /// Fails on invalid base64, on a failed download, or when neither
    /// source was supplied.
    pub async fn get_image_data(&self) -> anyhow::Result<Vec<u8>> {
        use base64::Engine;
        match (&self.base64, &self.url) {
            (Some(encoded), _) => {
                // Decode inline base64 payload
                Ok(base64::engine::general_purpose::STANDARD.decode(encoded)?)
            }
            (None, Some(url)) => {
                // Fetch the image over HTTP
                let bytes = reqwest::get(url).await?.bytes().await?;
                Ok(bytes.to_vec())
            }
            (None, None) => anyhow::bail!("No image data provided"),
        }
    }
}
/// Digital ink strokes request
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct StrokesRequest {
    /// Array of stroke data; validation requires at least one stroke
    #[validate(length(min = 1))]
    pub strokes: Vec<Stroke>,
    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
/// Single stroke in digital ink
///
/// NOTE(review): `x`, `y` (and `t` when present) are presumably parallel
/// arrays of equal length per point — no validation enforces this; confirm
/// with the recognizer's expectations.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Stroke {
    /// X coordinates
    pub x: Vec<f64>,
    /// Y coordinates
    pub y: Vec<f64>,
    /// Optional timestamps
    #[serde(skip_serializing_if = "Option::is_none")]
    pub t: Option<Vec<f64>>,
}
/// Legacy LaTeX equation request
///
/// Mirrors [`TextRequest`]'s image source fields for the legacy endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct LatexRequest {
    /// Image source
    #[serde(skip_serializing_if = "Option::is_none")]
    pub src: Option<String>,
    /// Base64 encoded image
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base64: Option<String>,
    /// Image URL (validated as a well-formed URL)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub url: Option<String>,
    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
/// PDF processing request
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct PdfRequest {
    /// PDF file URL (required; validated as a well-formed URL)
    #[validate(url)]
    pub url: String,
    /// Conversion options
    #[serde(default)]
    pub options: PdfOptions,
    /// Webhook URL for completion notification (validated when present)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub webhook_url: Option<String>,
    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
/// PDF processing options
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct PdfOptions {
/// Output format
#[serde(default = "default_format")]
pub format: String,
/// Enable OCR
#[serde(default)]
pub enable_ocr: bool,
/// Include images
#[serde(default = "default_true")]
pub include_images: bool,
/// Page range (e.g., "1-5")
#[serde(skip_serializing_if = "Option::is_none")]
pub page_range: Option<String>,
}
/// Serde default output format for PDF conversion.
fn default_format() -> String {
    String::from("mmd")
}

/// Serde default for boolean options that are enabled unless overridden.
fn default_true() -> bool {
    true
}
/// Request metadata
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct RequestMetadata {
/// Output formats
#[serde(default = "default_formats")]
pub formats: Vec<String>,
/// Include confidence scores
#[serde(default)]
pub include_confidence: bool,
/// Enable math mode
#[serde(default = "default_true")]
pub enable_math: bool,
/// Language hint
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
}
/// Serde default for requested output formats: plain text only.
fn default_formats() -> Vec<String> {
    vec![String::from("text")]
}
#[cfg(test)]
mod tests {
    use super::*;

    // A base64-only request passes validation.
    #[test]
    fn test_text_request_validation() {
        let request = TextRequest {
            src: None,
            base64: Some("SGVsbG8gV29ybGQ=".to_string()),
            url: None,
            metadata: RequestMetadata::default(),
        };
        assert!(request.validate().is_ok());
    }

    // One well-formed stroke satisfies the min-length rule.
    #[test]
    fn test_strokes_request_validation() {
        let request = StrokesRequest {
            strokes: vec![Stroke {
                x: vec![0.0, 1.0, 2.0],
                y: vec![0.0, 1.0, 0.0],
                t: None,
            }],
            metadata: RequestMetadata::default(),
        };
        assert!(request.validate().is_ok());
    }

    // An empty strokes array must fail the length(min = 1) rule.
    #[test]
    fn test_empty_strokes_validation() {
        let request = StrokesRequest {
            strokes: vec![],
            metadata: RequestMetadata::default(),
        };
        assert!(request.validate().is_err());
    }

    // A well-formed https URL passes the url validator.
    #[test]
    fn test_pdf_request_validation() {
        let request = PdfRequest {
            url: "https://example.com/document.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };
        assert!(request.validate().is_ok());
    }

    // A non-URL string must be rejected.
    #[test]
    fn test_invalid_url() {
        let request = PdfRequest {
            url: "not-a-url".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };
        assert!(request.validate().is_err());
    }
}

View File

@@ -0,0 +1,177 @@
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
Json,
};
use serde::{Deserialize, Serialize};
use super::jobs::JobStatus;
/// Standard text/OCR response
///
/// Optional alternate renderings are omitted from the JSON when absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextResponse {
    /// Unique request identifier
    pub request_id: String,
    /// Recognized text
    pub text: String,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f64,
    /// LaTeX output (if requested)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,
    /// MathML output (if requested)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mathml: Option<String>,
    /// HTML output (if requested)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub html: Option<String>,
}
/// PDF processing response
///
/// Returned both when a job is created and when its status is polled.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfResponse {
    /// PDF job identifier
    pub pdf_id: String,
    /// Current job status
    pub status: JobStatus,
    /// Status message
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,
    /// Processing result (when completed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<String>,
    /// Error details (if failed)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
/// Error response
///
/// Serialized as `{ "error_code": ..., "message": ... }`; the HTTP status
/// is carried out-of-band via `IntoResponse` and never appears in the body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Error code
    pub error_code: String,
    /// Human-readable error message
    pub message: String,
    /// HTTP status code (not serialized into the JSON body)
    #[serde(skip)]
    pub status: StatusCode,
}
impl ErrorResponse {
/// Create a validation error response
pub fn validation_error(message: impl Into<String>) -> Self {
Self {
error_code: "VALIDATION_ERROR".to_string(),
message: message.into(),
status: StatusCode::BAD_REQUEST,
}
}
/// Create an unauthorized error response
pub fn unauthorized(message: impl Into<String>) -> Self {
Self {
error_code: "UNAUTHORIZED".to_string(),
message: message.into(),
status: StatusCode::UNAUTHORIZED,
}
}
/// Create a not found error response
pub fn not_found(message: impl Into<String>) -> Self {
Self {
error_code: "NOT_FOUND".to_string(),
message: message.into(),
status: StatusCode::NOT_FOUND,
}
}
/// Create a rate limit error response
pub fn rate_limited(message: impl Into<String>) -> Self {
Self {
error_code: "RATE_LIMIT_EXCEEDED".to_string(),
message: message.into(),
status: StatusCode::TOO_MANY_REQUESTS,
}
}
/// Create an internal error response
pub fn internal_error(message: impl Into<String>) -> Self {
Self {
error_code: "INTERNAL_ERROR".to_string(),
message: message.into(),
status: StatusCode::INTERNAL_SERVER_ERROR,
}
}
/// Create a service unavailable error response
/// Used when the service is not fully configured (e.g., missing models)
pub fn service_unavailable(message: impl Into<String>) -> Self {
Self {
error_code: "SERVICE_UNAVAILABLE".to_string(),
message: message.into(),
status: StatusCode::SERVICE_UNAVAILABLE,
}
}
/// Create a not implemented error response
pub fn not_implemented(message: impl Into<String>) -> Self {
Self {
error_code: "NOT_IMPLEMENTED".to_string(),
message: message.into(),
status: StatusCode::NOT_IMPLEMENTED,
}
}
}
// Renders the error as (status, JSON body); `status` is copied out first
// because `Json(self)` consumes the struct.
impl IntoResponse for ErrorResponse {
    fn into_response(self) -> Response {
        let status = self.status;
        (status, Json(self)).into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // None-valued optional fields must be omitted from the JSON entirely.
    #[test]
    fn test_text_response_serialization() {
        let response = TextResponse {
            request_id: "test-123".to_string(),
            text: "Hello World".to_string(),
            confidence: 0.95,
            latex: Some("x^2".to_string()),
            mathml: None,
            html: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("request_id"));
        assert!(json.contains("test-123"));
        assert!(!json.contains("mathml"));
    }

    // Each helper constructor pins the matching status and error code.
    #[test]
    fn test_error_response_creation() {
        let error = ErrorResponse::validation_error("Invalid input");
        assert_eq!(error.status, StatusCode::BAD_REQUEST);
        assert_eq!(error.error_code, "VALIDATION_ERROR");
        let error = ErrorResponse::unauthorized("Invalid credentials");
        assert_eq!(error.status, StatusCode::UNAUTHORIZED);
        let error = ErrorResponse::rate_limited("Too many requests");
        assert_eq!(error.status, StatusCode::TOO_MANY_REQUESTS);
    }
}

View File

@@ -0,0 +1,103 @@
use axum::{
routing::{delete, get, post},
Router,
};
use tower::ServiceBuilder;
use tower_http::{
compression::CompressionLayer,
cors::CorsLayer,
trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
};
use tracing::Level;
use super::{
handlers::{
convert_document, delete_pdf_job, get_health, get_ocr_results, get_ocr_usage,
get_pdf_status, process_latex, process_pdf, process_strokes, process_text,
stream_pdf_results,
},
middleware::{auth_middleware, rate_limit_middleware},
state::AppState,
};
/// Create the main application router with all routes and middleware
///
/// Layering: in axum, `.layer` applies only to routes added BEFORE the call,
/// so auth + rate limiting wrap the `/v3/*` routes, while `/health` (merged
/// separately afterward) intentionally bypasses both. The outer trace/CORS/
/// compression layers wrap everything.
pub fn router(state: AppState) -> Router {
    // API v3 routes
    let api_routes = Router::new()
        // Image processing
        .route("/v3/text", post(process_text))
        // Digital ink processing
        .route("/v3/strokes", post(process_strokes))
        // Legacy equation processing
        .route("/v3/latex", post(process_latex))
        // Async PDF processing
        .route("/v3/pdf", post(process_pdf))
        .route("/v3/pdf/:id", get(get_pdf_status))
        .route("/v3/pdf/:id", delete(delete_pdf_job))
        .route("/v3/pdf/:id/stream", get(stream_pdf_results))
        // Document conversion
        .route("/v3/converter", post(convert_document))
        // History and usage
        .route("/v3/ocr-results", get(get_ocr_results))
        .route("/v3/ocr-usage", get(get_ocr_usage))
        // Apply auth and rate limiting to all API routes
        .layer(
            ServiceBuilder::new()
                .layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    auth_middleware,
                ))
                .layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    rate_limit_middleware,
                )),
        );
    // Health check (no auth required)
    let health_routes = Router::new().route("/health", get(get_health));
    // Combine all routes
    Router::new()
        .merge(api_routes)
        .merge(health_routes)
        .layer(
            ServiceBuilder::new()
                // Tracing layer
                .layer(
                    TraceLayer::new_for_http()
                        .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
                        .on_response(DefaultOnResponse::new().level(Level::INFO)),
                )
                // CORS layer — permissive; tighten for production deployments
                .layer(CorsLayer::permissive())
                // Compression layer
                .layer(CompressionLayer::new()),
        )
        .with_state(state)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    // End-to-end through the router: /health answers 200 without auth.
    #[tokio::test]
    async fn test_health_endpoint() {
        let state = AppState::new();
        let app = router(state);
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }
}

View File

@@ -0,0 +1,148 @@
use moka::future::Cache;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use super::{
jobs::JobQueue,
middleware::{create_rate_limiter, AppRateLimiter},
};
/// Shared application state
///
/// Cloning is cheap: every field is an `Arc`, a handle-style cache, or a
/// small value, so each request handler can hold its own copy.
#[derive(Clone)]
pub struct AppState {
    /// Job queue for async PDF processing
    pub job_queue: Arc<JobQueue>,
    /// Result cache
    pub cache: Cache<String, String>,
    /// Rate limiter
    pub rate_limiter: AppRateLimiter,
    /// Whether authentication is enabled
    pub auth_enabled: bool,
    /// Map of app_id -> hashed API key
    /// Keys should be stored as SHA-256 hashes, never in plaintext
    /// (see `hash_api_key` below).
    pub api_keys: Arc<HashMap<String, String>>,
}
impl AppState {
    /// Create a new application state instance with authentication disabled
    pub fn new() -> Self {
        Self {
            job_queue: Arc::new(JobQueue::new()),
            cache: create_cache(),
            rate_limiter: create_rate_limiter(),
            auth_enabled: false,
            api_keys: Arc::new(HashMap::new()),
        }
    }

    /// Create state with custom configuration
    ///
    /// `max_jobs` bounds the PDF job queue; `cache_size` bounds the result
    /// cache. TTL/TTI mirror the defaults used by `create_cache`.
    pub fn with_config(max_jobs: usize, cache_size: u64) -> Self {
        let cache = Cache::builder()
            .max_capacity(cache_size)
            .time_to_live(Duration::from_secs(3600))
            .time_to_idle(Duration::from_secs(600))
            .build();
        Self {
            job_queue: Arc::new(JobQueue::with_capacity(max_jobs)),
            cache,
            rate_limiter: create_rate_limiter(),
            auth_enabled: false,
            api_keys: Arc::new(HashMap::new()),
        }
    }

    /// Create state with authentication enabled
    ///
    /// Every provided plaintext key is SHA-256 hashed before storage.
    pub fn with_auth(api_keys: HashMap<String, String>) -> Self {
        let mut hashed = HashMap::with_capacity(api_keys.len());
        for (app_id, key) in api_keys {
            hashed.insert(app_id, hash_api_key(&key));
        }
        let mut state = Self::new();
        state.auth_enabled = true;
        state.api_keys = Arc::new(hashed);
        state
    }

    /// Add an API key (hashes the key before storing)
    ///
    /// Also turns authentication on.
    pub fn add_api_key(&mut self, app_id: String, api_key: &str) {
        Arc::make_mut(&mut self.api_keys).insert(app_id, hash_api_key(api_key));
        self.auth_enabled = true;
    }

    /// Enable or disable authentication
    pub fn set_auth_enabled(&mut self, enabled: bool) {
        self.auth_enabled = enabled;
    }
}
// Default state: unauthenticated, with the standard queue/cache/limits
// from `AppState::new`.
impl Default for AppState {
    fn default() -> Self {
        Self::new()
    }
}
/// Hash an API key with SHA-256, rendered as a lowercase hex string.
fn hash_api_key(key: &str) -> String {
    let digest = Sha256::digest(key.as_bytes());
    format!("{:x}", digest)
}
/// Default result cache: at most 10,000 entries; an entry expires one hour
/// after insertion or ten minutes after its last access, whichever first.
fn create_cache() -> Cache<String, String> {
    const MAX_ENTRIES: u64 = 10_000;
    Cache::builder()
        .max_capacity(MAX_ENTRIES)
        .time_to_live(Duration::from_secs(60 * 60))
        .time_to_idle(Duration::from_secs(10 * 60))
        .build()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state owns its job queue (at least one strong reference).
    #[tokio::test]
    async fn test_state_creation() {
        let state = AppState::new();
        assert!(Arc::strong_count(&state.job_queue) >= 1);
    }

    /// Custom limits still produce a usable state.
    #[tokio::test]
    async fn test_state_with_config() {
        let state = AppState::with_config(100, 5000);
        assert!(Arc::strong_count(&state.job_queue) >= 1);
    }

    /// Round-trip a value through the cache and miss on an unknown key.
    #[tokio::test]
    async fn test_cache_operations() {
        let state = AppState::new();

        state
            .cache
            .insert("key1".to_string(), "value1".to_string())
            .await;
        assert_eq!(
            state.cache.get(&"key1".to_string()).await,
            Some("value1".to_string())
        );
        assert_eq!(state.cache.get(&"missing".to_string()).await, None);
    }
}

View File

@@ -0,0 +1,763 @@
//! SciPix OCR Benchmark Tool
//!
//! Comprehensive benchmark for OCR performance including:
//! - Image preprocessing speed
//! - Text detection throughput
//! - Character recognition latency
//! - End-to-end pipeline benchmarks
use image::{DynamicImage, ImageBuffer, Luma, Rgb, RgbImage};
use imageproc::contrast::ThresholdType;
use imageproc::drawing::draw_filled_rect_mut;
use imageproc::rect::Rect;
use std::fs;
use std::path::PathBuf;
use std::time::{Duration, Instant};
// Import SIMD optimizations
use ruvector_scipix::optimize::simd::{
fast_area_resize, simd_grayscale, simd_resize_bilinear, simd_threshold,
};
/// Benchmark results
///
/// Aggregated wall-clock statistics for one named benchmark run
/// (see `run_benchmark`).
#[derive(Debug, Clone)]
struct BenchmarkResult {
    /// Human-readable benchmark label (used for lookups in `main`).
    name: String,
    /// Number of timed iterations (warmup runs excluded).
    iterations: usize,
    /// Sum of all per-iteration times.
    total_time: Duration,
    /// total_time / iterations.
    avg_time: Duration,
    /// Fastest single iteration.
    min_time: Duration,
    /// Slowest single iteration.
    max_time: Duration,
    /// Iterations per second over the whole run.
    throughput: f64,
}
impl BenchmarkResult {
    /// Pretty-print this result as a banner-delimited summary block.
    fn display(&self) {
        let rule = "=".repeat(60);
        println!("\n{}", rule);
        println!("Benchmark: {}", self.name);
        println!("{}", rule);
        println!("  Iterations:  {}", self.iterations);
        println!("  Total time:  {:?}", self.total_time);
        println!("  Avg time:    {:?}", self.avg_time);
        println!("  Min time:    {:?}", self.min_time);
        println!("  Max time:    {:?}", self.max_time);
        println!("  Throughput:  {:.2} ops/sec", self.throughput);
    }
}
/// Generate a synthetic "text-like" test image: white canvas, a row of ten
/// black blocks standing in for glyphs, and one long horizontal rule.
fn generate_test_image(width: u32, height: u32) -> RgbImage {
    let black = Rgb([0u8, 0u8, 0u8]);
    let mut canvas: RgbImage =
        ImageBuffer::from_fn(width, height, |_, _| Rgb([255u8, 255u8, 255u8]));

    // Ten evenly spaced "characters".
    for col in 0..10 {
        let x = (col * 35 + 10) as i32;
        draw_filled_rect_mut(&mut canvas, Rect::at(x, 20).of_size(25, 40), black);
    }

    // A fraction-bar style horizontal line.
    draw_filled_rect_mut(&mut canvas, Rect::at(10, 70).of_size(350, 2), black);
    canvas
}
/// Generate a math-like test image: a crude fraction (numerator, bar,
/// denominator) plus a square-root approximation (stem and overbar).
fn generate_math_image(width: u32, height: u32) -> RgbImage {
    let black = Rgb([0u8, 0u8, 0u8]);
    let mut canvas: RgbImage =
        ImageBuffer::from_fn(width, height, |_, _| Rgb([255u8, 255u8, 255u8]));

    // (x, y, w, h) for: numerator, fraction bar, denominator,
    // radical stem, radical overbar — drawn in that order.
    let boxes = [
        (50, 20, 100, 30),
        (20, 60, 160, 3),
        (70, 70, 60, 30),
        (200, 30, 5, 40),
        (200, 30, 80, 3),
    ];
    for (x, y, w, h) in boxes {
        draw_filled_rect_mut(&mut canvas, Rect::at(x, y).of_size(w, h), black);
    }
    canvas
}
/// Run `f` `iterations` times (after a short warmup) and collect timing
/// statistics.
///
/// Errors returned by `f` are deliberately ignored: a failing iteration is
/// still timed so the benchmark keeps measuring call overhead.
///
/// # Panics
///
/// Panics with an explicit message if `iterations` is zero — previously a
/// zero count would hit a divide-by-zero in the `Duration` division and an
/// `unwrap` on an empty vector, both with unhelpful messages.
fn run_benchmark<F, E>(name: &str, iterations: usize, mut f: F) -> BenchmarkResult
where
    F: FnMut() -> Result<(), E>,
    E: std::fmt::Debug,
{
    assert!(iterations > 0, "run_benchmark requires at least one iteration");

    let mut times = Vec::with_capacity(iterations);

    // Warmup: prime caches/allocators; results are discarded.
    for _ in 0..3 {
        let _ = f();
    }

    // Timed runs.
    for _ in 0..iterations {
        let start = Instant::now();
        let _ = f();
        times.push(start.elapsed());
    }

    let total_time: Duration = times.iter().sum();
    // NOTE: `iterations as u32` truncates above u32::MAX — fine for any
    // realistic benchmark count.
    let avg_time = total_time / iterations as u32;
    // Safe: `times` is non-empty because iterations > 0 (asserted above).
    let min_time = *times.iter().min().unwrap();
    let max_time = *times.iter().max().unwrap();
    let throughput = iterations as f64 / total_time.as_secs_f64();

    BenchmarkResult {
        name: name.to_string(),
        iterations,
        total_time,
        avg_time,
        min_time,
        max_time,
        throughput,
    }
}
/// Benchmark `image` crate grayscale conversion, cycling over the inputs.
fn benchmark_grayscale(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Grayscale Conversion", 500, || {
        let _gray = source.next().unwrap().to_luma8();
        Ok(())
    })
}
/// Benchmark high-quality (Lanczos3) resize to 640x480.
fn benchmark_resize(images: &[DynamicImage]) -> BenchmarkResult {
    use image::imageops::FilterType;
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Image Resize (640x480)", 100, || {
        let _resized = source.next().unwrap().resize(640, 480, FilterType::Lanczos3);
        Ok(())
    })
}
/// Benchmark fast nearest-neighbor resize to 640x480.
fn benchmark_fast_resize(images: &[DynamicImage]) -> BenchmarkResult {
    use image::imageops::FilterType;
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Fast Resize (Nearest)", 500, || {
        let _resized = source.next().unwrap().resize(640, 480, FilterType::Nearest);
        Ok(())
    })
}
/// Benchmark Gaussian blur (sigma 1.5) on the grayscale view of each image.
fn benchmark_blur(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Gaussian Blur (σ=1.5)", 50, || {
        let gray = source.next().unwrap().to_luma8();
        let _blurred = imageproc::filter::gaussian_blur_f32(&gray, 1.5);
        Ok(())
    })
}
/// Benchmark global binarization.
///
/// NOTE(review): the label says "Otsu" but this applies a fixed threshold
/// of 128, not Otsu's method — the label is kept because `main` looks
/// results up by name.
fn benchmark_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Otsu Threshold", 100, || {
        let gray = source.next().unwrap().to_luma8();
        let _thresholded = imageproc::contrast::threshold(&gray, 128, ThresholdType::Binary);
        Ok(())
    })
}
/// Benchmark adaptive (local-window) thresholding with block radius 11.
fn benchmark_adaptive_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Adaptive Threshold", 30, || {
        let gray = source.next().unwrap().to_luma8();
        let _thresholded = imageproc::contrast::adaptive_threshold(&gray, 11);
        Ok(())
    })
}
/// Benchmark raw memory traversal and copy on a one-million-element f32 buffer.
fn benchmark_memory_throughput() -> BenchmarkResult {
    let data: Vec<f32> = (0..1_000_000).map(|i| i as f32).collect();
    run_benchmark::<_, std::convert::Infallible>("Memory Throughput (1M floats)", 100, || {
        // Sequential read (fold == iter().sum() for f32) plus a full copy.
        let _sum = data.iter().fold(0.0f32, |acc, &x| acc + x);
        let _copy = data.to_vec();
        Ok(())
    })
}
/// Benchmark allocating a 1x3x224x224 f32 tensor (typical ONNX input).
fn benchmark_tensor_creation() -> BenchmarkResult {
    use ndarray::Array4;
    const SHAPE: (usize, usize, usize, usize) = (1, 3, 224, 224);
    run_benchmark::<_, ndarray::ShapeError>("Tensor Creation (1x3x224x224)", 100, || {
        let elems = SHAPE.0 * SHAPE.1 * SHAPE.2 * SHAPE.3;
        let _tensor = Array4::from_shape_vec(SHAPE, vec![0.0f32; elems])?;
        Ok(())
    })
}
/// Benchmark allocating a 1x3x640x480 f32 tensor (full-frame input).
fn benchmark_large_tensor() -> BenchmarkResult {
    use ndarray::Array4;
    const SHAPE: (usize, usize, usize, usize) = (1, 3, 640, 480);
    run_benchmark::<_, ndarray::ShapeError>("Large Tensor (1x3x640x480)", 50, || {
        let elems = SHAPE.0 * SHAPE.1 * SHAPE.2 * SHAPE.3;
        let _tensor = Array4::from_shape_vec(SHAPE, vec![0.0f32; elems])?;
        Ok(())
    })
}
/// Benchmark HWC→NCHW conversion with per-byte [-1, 1] normalization.
fn benchmark_normalization(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Image Normalization", 200, || {
        let rgb = source.next().unwrap().to_rgb8();
        let (w, h) = (rgb.width(), rgb.height());
        let mut tensor = Vec::with_capacity(3 * w as usize * h as usize);
        // Channel-major (NCHW) order, each byte mapped into [-1, 1].
        for c in 0..3 {
            for y in 0..h {
                for x in 0..w {
                    tensor.push((rgb.get_pixel(x, y)[c] as f32 / 127.5) - 1.0);
                }
            }
        }
        Ok(())
    })
}
/// Benchmark decoding an image file from disk.
///
/// Takes `&Path` rather than `&PathBuf` (the idiomatic borrowed form);
/// existing callers passing `&PathBuf` still work via deref coercion.
fn benchmark_image_load(path: &std::path::Path) -> BenchmarkResult {
    run_benchmark::<_, image::ImageError>("Image Load from Disk", 100, || {
        let _img = image::open(path)?;
        Ok(())
    })
}
/// Benchmark Sobel gradient computation on grayscale frames.
fn benchmark_edge_detection(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Sobel Edge Detection", 50, || {
        let gray = source.next().unwrap().to_luma8();
        let _edges = imageproc::gradients::sobel_gradients(&gray);
        Ok(())
    })
}
/// Benchmark 8-connected component labelling on a binarized frame.
fn benchmark_connected_components(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Connected Components", 50, || {
        let gray = source.next().unwrap().to_luma8();
        let binary = imageproc::contrast::threshold(&gray, 128, ThresholdType::Binary);
        let _cc = imageproc::region_labelling::connected_components(
            &binary,
            imageproc::region_labelling::Connectivity::Eight,
            Luma([0u8]),
        );
        Ok(())
    })
}
/// Benchmark the SIMD RGBA→grayscale kernel.
fn benchmark_simd_grayscale(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("SIMD Grayscale", 500, || {
        let rgba = source.next().unwrap().to_rgba8();
        let mut gray = vec![0u8; (rgba.width() * rgba.height()) as usize];
        simd_grayscale(rgba.as_raw(), &mut gray);
        Ok(())
    })
}
/// Benchmark the SIMD bilinear resize kernel (to 640x480).
fn benchmark_simd_resize(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("SIMD Resize (Bilinear)", 500, || {
        let gray = source.next().unwrap().to_luma8();
        let (w, h) = (gray.width() as usize, gray.height() as usize);
        let _resized = simd_resize_bilinear(gray.as_raw(), w, h, 640, 480);
        Ok(())
    })
}
/// Benchmark the fast area-averaging resize kernel (to 640x480).
fn benchmark_area_resize(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Fast Area Resize", 500, || {
        let gray = source.next().unwrap().to_luma8();
        let (w, h) = (gray.width() as usize, gray.height() as usize);
        let _resized = fast_area_resize(gray.as_raw(), w, h, 640, 480);
        Ok(())
    })
}
/// Benchmark the SIMD fixed-threshold binarization kernel.
fn benchmark_simd_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("SIMD Threshold", 500, || {
        let gray = source.next().unwrap().to_luma8();
        let mut out = vec![0u8; gray.as_raw().len()];
        simd_threshold(gray.as_raw(), 128, &mut out);
        Ok(())
    })
}
/// Full SIMD-optimized preprocessing chain:
/// grayscale → 224x224 resize → threshold → tensor normalization.
fn benchmark_simd_pipeline(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("SIMD Full Pipeline", 200, || {
        let rgba = source.next().unwrap().to_rgba8();
        let (w, h) = (rgba.width() as usize, rgba.height() as usize);

        // 1. RGBA → grayscale.
        let mut gray = vec![0u8; w * h];
        simd_grayscale(rgba.as_raw(), &mut gray);

        // 2. Downscale to the 224x224 model input.
        let resized = simd_resize_bilinear(&gray, w, h, 224, 224);

        // 3. Binarize at the fixed threshold.
        let mut binary = vec![0u8; resized.len()];
        simd_threshold(&resized, 128, &mut binary);

        // 4. Map bytes into [-1, 1] floats for the tensor.
        let _tensor: Vec<f32> = binary.iter().map(|&px| (px as f32 / 127.5) - 1.0).collect();
        Ok(())
    })
}
/// Baseline preprocessing chain using `image`/`imageproc` only (no SIMD),
/// kept for speedup comparison against `benchmark_simd_pipeline`.
fn benchmark_original_pipeline(images: &[DynamicImage]) -> BenchmarkResult {
    let mut source = images.iter().cycle();
    run_benchmark::<_, std::convert::Infallible>("Original Full Pipeline", 200, || {
        // grayscale → 224x224 nearest resize → fixed-threshold binarize
        // → [-1, 1] float tensor.
        let gray = source.next().unwrap().to_luma8();
        let resized =
            image::imageops::resize(&gray, 224, 224, image::imageops::FilterType::Nearest);
        let binary = imageproc::contrast::threshold(&resized, 128, ThresholdType::Binary);
        let _tensor: Vec<f32> = binary
            .as_raw()
            .iter()
            .map(|&px| (px as f32 / 127.5) - 1.0)
            .collect();
        Ok(())
    })
}
/// Benchmark entry point: generate synthetic fixtures, run every benchmark
/// group, then print per-benchmark results, a summary table, and a
/// performance analysis including SIMD-vs-original speedups and a
/// state-of-the-art comparison.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("\n{}", "=".repeat(60));
    println!("  SciPix OCR Benchmark Suite");
    println!("{}", "=".repeat(60));
    println!("\nGenerating test images...");
    // Generate test images
    let text_image = generate_test_image(400, 100);
    let math_image = generate_math_image(300, 150);
    let large_image = generate_test_image(800, 200);
    let hd_image = generate_test_image(1920, 1080);
    // Save test images
    let test_dir = PathBuf::from("test_images");
    fs::create_dir_all(&test_dir)?;
    text_image.save(test_dir.join("text_test.png"))?;
    math_image.save(test_dir.join("math_test.png"))?;
    large_image.save(test_dir.join("large_test.png"))?;
    hd_image.save(test_dir.join("hd_test.png"))?;
    println!("Test images saved to test_images/\n");
    // Convert to DynamicImage for benchmarks
    let images: Vec<DynamicImage> = vec![
        DynamicImage::ImageRgb8(text_image.clone()),
        DynamicImage::ImageRgb8(math_image.clone()),
        DynamicImage::ImageRgb8(large_image.clone()),
    ];
    let hd_images = vec![DynamicImage::ImageRgb8(hd_image.clone())];
    // Run benchmarks; results are collected in order and later looked up
    // by name for the analysis section.
    let mut results = Vec::new();
    println!("Running image conversion benchmarks...");
    results.push(benchmark_grayscale(&images));
    println!("Running resize benchmarks...");
    results.push(benchmark_resize(&images));
    results.push(benchmark_fast_resize(&images));
    println!("Running filter benchmarks...");
    results.push(benchmark_blur(&images));
    results.push(benchmark_threshold(&images));
    results.push(benchmark_adaptive_threshold(&images));
    results.push(benchmark_edge_detection(&images));
    results.push(benchmark_connected_components(&images));
    println!("Running SIMD optimized benchmarks...");
    results.push(benchmark_simd_grayscale(&images));
    results.push(benchmark_simd_resize(&images));
    results.push(benchmark_area_resize(&images));
    results.push(benchmark_simd_threshold(&images));
    println!("Running pipeline benchmarks...");
    results.push(benchmark_original_pipeline(&images));
    results.push(benchmark_simd_pipeline(&images));
    println!("Running normalization benchmarks...");
    results.push(benchmark_normalization(&images));
    println!("Running memory benchmarks...");
    results.push(benchmark_memory_throughput());
    results.push(benchmark_tensor_creation());
    results.push(benchmark_large_tensor());
    println!("Running I/O benchmarks...");
    results.push(benchmark_image_load(&test_dir.join("text_test.png")));
    println!("\nRunning HD image benchmarks...");
    results.push(run_benchmark::<_, std::convert::Infallible>(
        "HD Grayscale (1920x1080)",
        100,
        || {
            let _gray = hd_images[0].to_luma8();
            Ok(())
        },
    ));
    results.push(run_benchmark::<_, std::convert::Infallible>(
        "HD Resize to 640x480",
        50,
        || {
            let _resized = hd_images[0].resize(640, 480, image::imageops::FilterType::Lanczos3);
            Ok(())
        },
    ));
    // Display results
    println!("\n\n{}", "#".repeat(60));
    println!("  BENCHMARK RESULTS");
    println!("{}", "#".repeat(60));
    for result in &results {
        result.display();
    }
    // Summary table
    println!("\n\n{}", "=".repeat(75));
    println!("{:45} {:>15} {:>15}", "Benchmark", "Avg Time", "Throughput");
    println!("{}", "-".repeat(75));
    for result in &results {
        println!(
            "{:45} {:>15.2?} {:>12.2} ops/s",
            result.name, result.avg_time, result.throughput
        );
    }
    println!("{}", "=".repeat(75));
    // Performance analysis
    println!("\n{}", "=".repeat(60));
    println!("  PERFORMANCE ANALYSIS");
    println!("{}", "=".repeat(60));
    // Calculate total preprocessing time for a typical pipeline.
    // Results are looked up by the exact name string each benchmark used;
    // a missing name silently contributes Duration::ZERO via unwrap_or_default.
    let grayscale_time = results
        .iter()
        .find(|r| r.name == "Grayscale Conversion")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let resize_time = results
        .iter()
        .find(|r| r.name == "Fast Resize (Nearest)")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let threshold_time = results
        .iter()
        .find(|r| r.name == "Otsu Threshold")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let normalize_time = results
        .iter()
        .find(|r| r.name == "Image Normalization")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let total_preprocess = grayscale_time + resize_time + threshold_time + normalize_time;
    // SIMD optimized times
    let simd_grayscale = results
        .iter()
        .find(|r| r.name == "SIMD Grayscale")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let simd_resize = results
        .iter()
        .find(|r| r.name == "SIMD Resize (Bilinear)")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    // NOTE(review): this local shadows the imported `simd_threshold`
    // function; harmless here (the function is not called afterwards),
    // but a distinct name would be clearer.
    let simd_threshold = results
        .iter()
        .find(|r| r.name == "SIMD Threshold")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let original_pipeline = results
        .iter()
        .find(|r| r.name == "Original Full Pipeline")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    let simd_pipeline = results
        .iter()
        .find(|r| r.name == "SIMD Full Pipeline")
        .map(|r| r.avg_time)
        .unwrap_or_default();
    // SIMD vs original comparison table; speedup guards against a zero
    // SIMD time (division by zero) by reporting 1.0x.
    println!("\n┌──────────────────────────────────────────────────────────────────┐");
    println!("│ SIMD Optimization Comparison │");
    println!("├────────────────────┬──────────────┬──────────────┬───────────────┤");
    println!("│ Operation │ Original │ SIMD │ Speedup │");
    println!("├────────────────────┼──────────────┼──────────────┼───────────────┤");
    println!(
        "│ Grayscale │ {:>10.2?}{:>10.2?}{:>6.2}x │",
        grayscale_time,
        simd_grayscale,
        if simd_grayscale.as_nanos() > 0 {
            grayscale_time.as_secs_f64() / simd_grayscale.as_secs_f64()
        } else {
            1.0
        }
    );
    println!(
        "│ Resize │ {:>10.2?}{:>10.2?}{:>6.2}x │",
        resize_time,
        simd_resize,
        if simd_resize.as_nanos() > 0 {
            resize_time.as_secs_f64() / simd_resize.as_secs_f64()
        } else {
            1.0
        }
    );
    println!(
        "│ Threshold │ {:>10.2?}{:>10.2?}{:>6.2}x │",
        threshold_time,
        simd_threshold,
        if simd_threshold.as_nanos() > 0 {
            threshold_time.as_secs_f64() / simd_threshold.as_secs_f64()
        } else {
            1.0
        }
    );
    println!("├────────────────────┼──────────────┼──────────────┼───────────────┤");
    println!(
        "│ Full Pipeline │ {:>10.2?}{:>10.2?}{:>6.2}x │",
        original_pipeline,
        simd_pipeline,
        if simd_pipeline.as_nanos() > 0 {
            original_pipeline.as_secs_f64() / simd_pipeline.as_secs_f64()
        } else {
            1.0
        }
    );
    println!("└────────────────────┴──────────────┴──────────────┴───────────────┘");
    println!("\n┌──────────────────────────────────────────────────┐");
    println!("│ Typical Preprocessing Pipeline Breakdown │");
    println!("├──────────────────────────────────────────────────┤");
    println!(
        "│ Grayscale: {:>10.2?} ({:.1}%) │",
        grayscale_time,
        100.0 * grayscale_time.as_secs_f64() / total_preprocess.as_secs_f64()
    );
    println!(
        "│ Resize: {:>10.2?} ({:.1}%) │",
        resize_time,
        100.0 * resize_time.as_secs_f64() / total_preprocess.as_secs_f64()
    );
    println!(
        "│ Threshold: {:>10.2?} ({:.1}%) │",
        threshold_time,
        100.0 * threshold_time.as_secs_f64() / total_preprocess.as_secs_f64()
    );
    println!(
        "│ Normalization: {:>10.2?} ({:.1}%) │",
        normalize_time,
        100.0 * normalize_time.as_secs_f64() / total_preprocess.as_secs_f64()
    );
    println!("├──────────────────────────────────────────────────┤");
    println!(
        "│ TOTAL: {:>10.2?}",
        total_preprocess
    );
    println!("└──────────────────────────────────────────────────┘");
    println!("\nTarget latency for real-time (30 fps): 33.3ms");
    if total_preprocess.as_millis() < 33 {
        println!(
            "✓ Preprocessing meets real-time requirements ({:.1}ms < 33.3ms)",
            total_preprocess.as_secs_f64() * 1000.0
        );
    } else {
        println!(
            "⚠ Preprocessing exceeds real-time target ({:.1}ms > 33.3ms)",
            total_preprocess.as_secs_f64() * 1000.0
        );
    }
    // Memory efficiency
    let tensor_throughput = results
        .iter()
        .find(|r| r.name.contains("Tensor Creation"))
        .map(|r| r.throughput)
        .unwrap_or(0.0);
    println!(
        "\nTensor creation throughput: {:.0} tensors/sec",
        tensor_throughput
    );
    println!("Target for batch inference: >100 tensors/sec");
    if tensor_throughput > 100.0 {
        println!("✓ Tensor creation meets batch requirements");
    } else {
        println!("⚠ Consider tensor pooling optimization");
    }
    // Estimated end-to-end throughput
    let estimated_ocr_time = total_preprocess.as_secs_f64() * 1000.0 + 50.0; // preprocessing + estimated inference
    let estimated_throughput = 1000.0 / estimated_ocr_time;
    println!("\n┌──────────────────────────────────────────────────┐");
    println!("│ Estimated End-to-End Performance │");
    println!("├──────────────────────────────────────────────────┤");
    println!(
        "│ Preprocessing: {:>8.2}ms │",
        total_preprocess.as_secs_f64() * 1000.0
    );
    println!("│ Est. Inference: {:>8.2}ms (target) │", 50.0);
    println!(
        "│ Total latency: {:>8.2}ms │",
        estimated_ocr_time
    );
    println!(
        "│ Throughput: {:>8.1} images/sec │",
        estimated_throughput
    );
    println!("└──────────────────────────────────────────────────┘");
    // State of the art comparison (reference numbers are hard-coded
    // ballpark figures, not measured here)
    println!("\n{}", "=".repeat(60));
    println!("  STATE OF THE ART COMPARISON");
    println!("{}", "=".repeat(60));
    println!("\n┌────────────────────────────────────────────────────────┐");
    println!("│ System │ Latency │ Throughput │ Status │");
    println!("├────────────────────────────────────────────────────────┤");
    println!("│ Tesseract │ ~200ms │ ~5 img/s │ Slow │");
    println!("│ PaddleOCR │ ~50ms │ ~20 img/s │ Fast │");
    println!("│ EasyOCR │ ~100ms │ ~10 img/s │ Medium │");
    println!(
        "│ SciPix (est.) │ {:>6.1}ms │ {:>6.1} img/s │ {}",
        estimated_ocr_time,
        estimated_throughput,
        if estimated_throughput > 15.0 {
            "Fast "
        } else if estimated_throughput > 8.0 {
            "Medium "
        } else {
            "Slow "
        }
    );
    println!("└────────────────────────────────────────────────────────┘");
    println!("\n{}", "=".repeat(60));
    println!("Benchmark complete!");
    println!("{}", "=".repeat(60));
    Ok(())
}

View File

@@ -0,0 +1,66 @@
use anyhow::Result;
use clap::Parser;
use ruvector_scipix::cli::{Cli, Commands};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
// Initialize logging based on verbosity
let log_level = if cli.quiet {
tracing::Level::ERROR
} else if cli.verbose {
tracing::Level::DEBUG
} else {
tracing::Level::INFO
};
tracing_subscriber::registry()
.with(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| format!("{}={}", env!("CARGO_PKG_NAME"), log_level).into()),
)
.with(tracing_subscriber::fmt::layer())
.init();
// Execute the command
match &cli.command {
Commands::Ocr(args) => {
ruvector_scipix::cli::commands::ocr::execute(args.clone(), &cli).await?;
}
Commands::Batch(args) => {
ruvector_scipix::cli::commands::batch::execute(args.clone(), &cli).await?;
}
Commands::Serve(args) => {
ruvector_scipix::cli::commands::serve::execute(args.clone(), &cli).await?;
}
Commands::Mcp(args) => {
ruvector_scipix::cli::commands::mcp::run(args.clone()).await?;
}
Commands::Config(args) => {
ruvector_scipix::cli::commands::config::execute(args.clone(), &cli).await?;
}
Commands::Doctor(args) => {
ruvector_scipix::cli::commands::doctor::execute(args.clone()).await?;
}
Commands::Version => {
println!("scipix-cli v{}", env!("CARGO_PKG_VERSION"));
println!("A Rust-based CLI for Scipix OCR processing");
}
Commands::Completions { shell } => {
use clap::CommandFactory;
use clap_complete::{generate, Shell};
let shell = shell
.clone()
.unwrap_or_else(|| Shell::from_env().unwrap_or(Shell::Bash));
let mut cmd = Cli::command();
let bin_name = cmd.get_name().to_string();
generate(shell, &mut cmd, bin_name, &mut std::io::stdout());
}
}
Ok(())
}

View File

@@ -0,0 +1,37 @@
use anyhow::Result;
use std::net::SocketAddr;
use tracing::info;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
use ruvector_scipix::api::{state::AppState, ApiServer};
#[tokio::main]
async fn main() -> Result<()> {
    // Load `.env` BEFORE anything reads the environment. Previously this ran
    // after tracing was initialized, so a RUST_LOG value in `.env` was
    // silently ignored by `EnvFilter::try_from_default_env()`.
    dotenvy::dotenv().ok();

    // Initialize tracing; env filter falls back to a verbose dev default.
    tracing_subscriber::registry()
        .with(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "scipix_server=debug,tower_http=debug,axum=trace".into()),
        )
        .with(tracing_subscriber::fmt::layer())
        .init();

    info!("Initializing Scipix API Server");

    // Create application state
    let state = AppState::new();

    // Bind address comes from SERVER_ADDR, defaulting to localhost:3000.
    let addr = std::env::var("SERVER_ADDR")
        .unwrap_or_else(|_| "127.0.0.1:3000".to_string())
        .parse::<SocketAddr>()?;

    // Create and start server
    let server = ApiServer::new(state, addr);
    server.start().await?;

    Ok(())
}

488
examples/scipix/src/cache/mod.rs vendored Normal file
View File

@@ -0,0 +1,488 @@
//! Vector-based intelligent caching for Scipix OCR results
//!
//! Uses ruvector-core for efficient similarity search and LRU eviction.
use crate::config::CacheConfig;
use crate::error::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// Cached OCR result with metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// LaTeX output
    pub latex: String,
    /// Alternative formats (MathML, AsciiMath), keyed by format name
    pub alternatives: HashMap<String, String>,
    /// Confidence score
    pub confidence: f32,
    /// Cache timestamp — presumably seconds since the Unix epoch, matching
    /// `CacheManager::current_timestamp`; set by the caller (confirm)
    pub timestamp: u64,
    /// Access count
    pub access_count: usize,
    /// Image hash (hex digest of the source image bytes)
    pub image_hash: String,
}
/// Cache entry with vector embedding
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Vector embedding of image (length = `CacheConfig::vector_dimension`)
    embedding: Vec<f32>,
    /// Cached result
    result: CachedResult,
    /// Last access time, seconds since the Unix epoch; drives TTL expiry
    last_access: u64,
}
/// Vector-based cache manager
///
/// All interior state sits behind `Arc<RwLock<..>>` so the manager can be
/// shared across threads. Methods acquire `entries` before `lru_order` or
/// `stats`, keeping the lock order consistent.
pub struct CacheManager {
    /// Configuration
    config: CacheConfig,
    /// Cache entries keyed by image hash (thread-safe)
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,
    /// LRU tracking: keys ordered oldest-first, most recent at the tail
    lru_order: Arc<RwLock<Vec<String>>>,
    /// Cache statistics
    stats: Arc<RwLock<CacheStats>>,
}
/// Cache statistics
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Total cache hits (exact-hash and similarity hits combined)
    pub hits: u64,
    /// Total cache misses
    pub misses: u64,
    /// Total entries currently stored
    pub entries: usize,
    /// Total evictions
    pub evictions: u64,
    /// Average similarity score for hits (updated only on similarity hits)
    pub avg_similarity: f32,
}
impl CacheStats {
    /// Fraction of lookups served from cache, in [0, 1].
    /// Returns 0.0 when no lookups have been recorded yet.
    pub fn hit_rate(&self) -> f32 {
        let total = self.hits + self.misses;
        if total == 0 {
            0.0
        } else {
            self.hits as f32 / total as f32
        }
    }
}
impl CacheManager {
/// Create new cache manager
///
/// # Arguments
///
/// * `config` - Cache configuration
///
/// # Examples
///
/// ```rust
/// use ruvector_scipix::{CacheConfig, cache::CacheManager};
///
/// let config = CacheConfig {
/// enabled: true,
/// capacity: 1000,
/// similarity_threshold: 0.95,
/// ttl: 3600,
/// vector_dimension: 512,
/// persistent: false,
/// cache_dir: ".cache".to_string(),
/// };
///
/// let cache = CacheManager::new(config);
/// ```
pub fn new(config: CacheConfig) -> Self {
Self {
config,
entries: Arc::new(RwLock::new(HashMap::new())),
lru_order: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(CacheStats::default())),
}
}
/// Generate embedding for image
///
/// This is a placeholder - in production, use actual vision model
fn generate_embedding(&self, image_data: &[u8]) -> Result<Vec<f32>> {
// Placeholder: Simple hash-based embedding
// In production: Use Vision Transformer or similar
let hash = self.hash_image(image_data);
let mut embedding = vec![0.0; self.config.vector_dimension];
for (i, byte) in hash.as_bytes().iter().enumerate() {
if i < embedding.len() {
embedding[i] = *byte as f32 / 255.0;
}
}
Ok(embedding)
}
/// Hash image data
fn hash_image(&self, image_data: &[u8]) -> String {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
image_data.hash(&mut hasher);
format!("{:x}", hasher.finish())
}
/// Calculate cosine similarity between vectors
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
if a.len() != b.len() {
return 0.0;
}
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
return 0.0;
}
dot_product / (norm_a * norm_b)
}
    /// Look up cached result by image similarity
    ///
    /// # Arguments
    ///
    /// * `image_data` - Raw image bytes
    ///
    /// # Returns
    ///
    /// Cached result if an exact hash match or an entry with cosine
    /// similarity >= `config.similarity_threshold` is found; `None`
    /// otherwise (including when caching is disabled).
    pub fn lookup(&self, image_data: &[u8]) -> Result<Option<CachedResult>> {
        if !self.config.enabled {
            return Ok(None);
        }
        let embedding = self.generate_embedding(image_data)?;
        let hash = self.hash_image(image_data);
        // Read lock is held for the whole scan; `record_*`/`update_lru`
        // take the stats/lru locks while it is held (order: entries first).
        let entries = self.entries.read().unwrap();
        // First try exact hash match
        if let Some(entry) = entries.get(&hash) {
            if !self.is_expired(&entry) {
                self.record_hit();
                self.update_lru(&hash);
                return Ok(Some(entry.result.clone()));
            }
        }
        // Then try similarity search — O(n) linear scan over all live entries
        let mut best_match: Option<(String, f32, CachedResult)> = None;
        for (key, entry) in entries.iter() {
            if self.is_expired(entry) {
                continue;
            }
            let similarity = self.cosine_similarity(&embedding, &entry.embedding);
            if similarity >= self.config.similarity_threshold {
                if best_match.is_none() || similarity > best_match.as_ref().unwrap().1 {
                    best_match = Some((key.clone(), similarity, entry.result.clone()));
                }
            }
        }
        if let Some((key, similarity, result)) = best_match {
            // NOTE(review): only similarity hits feed `avg_similarity`, yet
            // exact-hash hits above still grow the hit count used as the
            // average's denominator — the average is skewed; confirm intent.
            self.record_hit_with_similarity(similarity);
            self.update_lru(&key);
            Ok(Some(result))
        } else {
            self.record_miss();
            Ok(None)
        }
    }
    /// Store result in cache
    ///
    /// # Arguments
    ///
    /// * `image_data` - Raw image bytes
    /// * `result` - OCR result to cache
    ///
    /// No-op when caching is disabled. When the cache is at capacity and
    /// the key is new, the least-recently-used entry is evicted first.
    pub fn store(&self, image_data: &[u8], result: CachedResult) -> Result<()> {
        if !self.config.enabled {
            return Ok(());
        }
        let embedding = self.generate_embedding(image_data)?;
        let hash = self.hash_image(image_data);
        let entry = CacheEntry {
            embedding,
            result,
            last_access: self.current_timestamp(),
        };
        // Write lock held across the capacity check + insert so concurrent
        // stores cannot overshoot `config.capacity`.
        let mut entries = self.entries.write().unwrap();
        // Check if we need to evict
        if entries.len() >= self.config.capacity && !entries.contains_key(&hash) {
            self.evict_lru(&mut entries);
        }
        entries.insert(hash.clone(), entry);
        self.update_lru(&hash);
        self.update_stats_entries(entries.len());
        Ok(())
    }
/// Check if entry is expired
///
/// An entry expires once more than `config.ttl` seconds have elapsed since
/// its `last_access` timestamp (set at store time).
fn is_expired(&self, entry: &CacheEntry) -> bool {
    let current = self.current_timestamp();
    // saturating_sub: if the clock moved backwards (last_access in the future),
    // the previous `current - entry.last_access` underflowed — panicking in
    // debug builds and wrapping to a huge value (spurious expiry) in release.
    // Saturating to 0 treats such entries as fresh instead.
    current.saturating_sub(entry.last_access) > self.config.ttl
}
/// Get current timestamp
///
/// Seconds since the Unix epoch; panics only if the system clock reads
/// earlier than the epoch.
fn current_timestamp(&self) -> u64 {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
    since_epoch.as_secs()
}
/// Evict the least recently used entry (front of the LRU list), if any.
fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
    let mut lru = self.lru_order.write().unwrap();
    if lru.is_empty() {
        return;
    }
    // Front of the list is the coldest key; remove it from both structures.
    let key = lru.remove(0);
    entries.remove(&key);
    self.record_eviction();
}
/// Move `key` to the back of the LRU list (most recently used position).
fn update_lru(&self, key: &str) {
    let mut order = self.lru_order.write().unwrap();
    order.retain(|existing| existing != key);
    order.push(key.to_owned());
}
/// Record a cache hit in the shared statistics.
fn record_hit(&self) {
    self.stats.write().unwrap().hits += 1;
}
/// Record cache hit with similarity
///
/// Increments the hit counter and folds `similarity` into a rolling average.
fn record_hit_with_similarity(&self, similarity: f32) {
    let mut stats = self.stats.write().unwrap();
    stats.hits += 1;
    // Update rolling average.
    // NOTE(review): `hits` also counts exact-hash hits (record_hit), which
    // contribute no similarity sample — so this average is diluted toward its
    // previous value whenever exact hits occur. Confirm whether that is intended
    // or whether a separate similarity-hit counter is needed.
    let total = stats.hits as f32;
    stats.avg_similarity = (stats.avg_similarity * (total - 1.0) + similarity) / total;
}
/// Record a cache miss in the shared statistics.
fn record_miss(&self) {
    self.stats.write().unwrap().misses += 1;
}
/// Record an eviction in the shared statistics.
fn record_eviction(&self) {
    self.stats.write().unwrap().evictions += 1;
}
/// Publish the current entry count to the shared statistics.
fn update_stats_entries(&self, count: usize) {
    self.stats.write().unwrap().entries = count;
}
/// Snapshot of the current cache statistics.
pub fn stats(&self) -> CacheStats {
    let guard = self.stats.read().unwrap();
    guard.clone()
}
/// Drop every cached entry and reset the LRU order.
///
/// Both locks are held together so no reader observes a half-cleared cache.
pub fn clear(&self) {
    let mut entry_map = self.entries.write().unwrap();
    let mut order = self.lru_order.write().unwrap();
    entry_map.clear();
    order.clear();
    self.update_stats_entries(0);
}
/// Remove expired entries from the cache and the LRU order.
pub fn cleanup(&self) {
    let mut entries = self.entries.write().unwrap();
    let mut lru = self.lru_order.write().unwrap();
    // Drop expired entries in a single pass instead of collect-then-remove.
    entries.retain(|_, entry| !self.is_expired(entry));
    // One pass over the LRU list keeps only keys that still exist; the previous
    // per-expired-key `lru.retain` loop was O(expired * lru_len).
    lru.retain(|key| entries.contains_key(key));
    self.update_stats_entries(entries.len());
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small cache configuration shared by every test below.
    fn test_config() -> CacheConfig {
        CacheConfig {
            enabled: true,
            capacity: 100,
            similarity_threshold: 0.95,
            ttl: 3600,
            vector_dimension: 128,
            persistent: false,
            cache_dir: ".cache/test".to_string(),
        }
    }

    /// Canned OCR result used as the cached payload.
    fn test_result() -> CachedResult {
        CachedResult {
            latex: r"\frac{x^2}{2}".to_string(),
            alternatives: HashMap::new(),
            confidence: 0.95,
            timestamp: 0,
            access_count: 0,
            image_hash: "test".to_string(),
        }
    }

    #[test]
    fn test_cache_creation() {
        let cache = CacheManager::new(test_config());
        let stats = cache.stats();
        assert_eq!(stats.hits, 0);
        assert_eq!(stats.misses, 0);
    }

    #[test]
    fn test_store_and_lookup() {
        let cache = CacheManager::new(test_config());
        let image = b"test image data";
        let stored = test_result();
        cache.store(image, stored.clone()).unwrap();
        let found = cache.lookup(image).unwrap();
        assert!(found.is_some());
        assert_eq!(found.unwrap().latex, stored.latex);
    }

    #[test]
    fn test_cache_miss() {
        let cache = CacheManager::new(test_config());
        assert!(cache.lookup(b"nonexistent image").unwrap().is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    #[test]
    fn test_cache_hit_rate() {
        let cache = CacheManager::new(test_config());
        let image = b"test image";
        cache.store(image, test_result()).unwrap();
        // Two hits on the stored image, then one miss on an unknown one.
        cache.lookup(image).unwrap();
        cache.lookup(image).unwrap();
        cache.lookup(b"different image").unwrap();
        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    #[test]
    fn test_cosine_similarity() {
        let cache = CacheManager::new(test_config());
        let unit_x = vec![1.0, 0.0, 0.0];
        let unit_y = vec![0.0, 1.0, 0.0];
        // Identical vectors -> 1.0; orthogonal vectors -> 0.0.
        assert!((cache.cosine_similarity(&unit_x, &unit_x) - 1.0).abs() < 0.01);
        assert!((cache.cosine_similarity(&unit_x, &unit_y) - 0.0).abs() < 0.01);
    }

    #[test]
    fn test_cache_clear() {
        let cache = CacheManager::new(test_config());
        cache.store(b"test image", test_result()).unwrap();
        assert_eq!(cache.stats().entries, 1);
        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }

    #[test]
    fn test_disabled_cache() {
        let config = CacheConfig {
            enabled: false,
            ..test_config()
        };
        let cache = CacheManager::new(config);
        cache.store(b"test image", test_result()).unwrap();
        // With the cache disabled, store is a no-op and lookup always misses.
        assert!(cache.lookup(b"test image").unwrap().is_none());
    }
}

View File

@@ -0,0 +1,399 @@
use anyhow::{Context, Result};
use clap::Args;
use glob::glob;
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::Semaphore;
use tracing::{debug, error, info, warn};
use super::{OcrConfig, OcrResult};
use crate::cli::{output, Cli, OutputFormat};
/// Process multiple files in batch mode
///
/// Notes on field interplay (see `execute`): results whose confidence falls
/// below `threshold` are excluded from the saved output, and `separate_files`
/// only has an effect when `output` is also given.
#[derive(Args, Debug, Clone)]
pub struct BatchArgs {
    /// Input pattern (glob) or directory
    #[arg(value_name = "PATTERN", help = "Input pattern (glob) or directory")]
    pub pattern: String,
    /// Output directory for results
    #[arg(
        short,
        long,
        value_name = "DIR",
        help = "Output directory for results (default: stdout as JSON array)"
    )]
    pub output: Option<PathBuf>,
    /// Number of parallel workers
    #[arg(
        short,
        long,
        default_value = "4",
        help = "Number of parallel processing workers"
    )]
    pub parallel: usize,
    /// Minimum confidence threshold (0.0 to 1.0)
    #[arg(
        short = 't',
        long,
        default_value = "0.7",
        help = "Minimum confidence threshold for results"
    )]
    pub threshold: f64,
    /// Continue on errors
    #[arg(
        short = 'c',
        long,
        help = "Continue processing even if some files fail"
    )]
    pub continue_on_error: bool,
    /// Maximum retry attempts per file
    #[arg(
        short = 'r',
        long,
        default_value = "2",
        help = "Maximum retry attempts per file on failure"
    )]
    pub max_retries: usize,
    /// Save individual results as separate files
    #[arg(long, help = "Save each result as a separate file (requires --output)")]
    pub separate_files: bool,
    /// Recursive directory search
    #[arg(short = 'R', long, help = "Recursively search directories")]
    pub recursive: bool,
}
/// Run batch OCR over every file matched by `args.pattern`.
///
/// Below-threshold results are counted in the summary but are NOT written to
/// the output; only passing results are saved or printed.
pub async fn execute(args: BatchArgs, cli: &Cli) -> Result<()> {
    info!("Starting batch processing with pattern: {}", args.pattern);
    // Load configuration
    let config = Arc::new(load_config(cli.config.as_ref())?);
    // Expand pattern to file list
    let files = collect_files(&args)?;
    if files.is_empty() {
        anyhow::bail!("No files found matching pattern: {}", args.pattern);
    }
    info!("Found {} files to process", files.len());
    // Create output directory if needed
    if let Some(output_dir) = &args.output {
        std::fs::create_dir_all(output_dir).context("Failed to create output directory")?;
    }
    // Process files in parallel with progress bars
    let results = process_files_parallel(files, &args, &config, cli.quiet).await?;
    // Filter by confidence threshold. "failed" here means below-threshold, not
    // a processing error (those are dropped inside process_files_parallel).
    let (passed, failed): (Vec<_>, Vec<_>) = results
        .into_iter()
        .partition(|r| r.confidence >= args.threshold);
    info!(
        "Processing complete: {} passed, {} failed threshold",
        passed.len(),
        failed.len()
    );
    // Save or display results
    if let Some(output_dir) = &args.output {
        save_results(&passed, output_dir, &cli.format, args.separate_files)?;
        if !cli.quiet {
            println!("Results saved to: {}", output_dir.display());
        }
    } else {
        // Output as JSON array to stdout
        let json = serde_json::to_string_pretty(&passed).context("Failed to serialize results")?;
        println!("{}", json);
    }
    // Display summary
    if !cli.quiet {
        output::print_batch_summary(&passed, &failed, args.threshold);
    }
    // Return error if any files failed and continue_on_error is false
    if !failed.is_empty() && !args.continue_on_error {
        anyhow::bail!("{} files failed confidence threshold", failed.len());
    }
    Ok(())
}
/// Expand `args.pattern` into a list of regular files.
///
/// A directory argument is turned into a `dir/*` (or `dir/**/*` when
/// `--recursive`) glob; anything else is treated as a glob pattern directly.
fn collect_files(args: &BatchArgs) -> Result<Vec<PathBuf>> {
    // Shared glob walk (previously duplicated in both branches): keep only
    // plain files; unreadable entries are logged and skipped.
    fn push_matches(pattern: &str, files: &mut Vec<PathBuf>) -> Result<()> {
        for entry in glob(pattern).context("Failed to read glob pattern")? {
            match entry {
                Ok(path) if path.is_file() => files.push(path),
                Ok(_) => {}
                Err(e) => warn!("Failed to read entry: {}", e),
            }
        }
        Ok(())
    }

    let mut files = Vec::new();
    if PathBuf::from(&args.pattern).is_dir() {
        // Directory mode: list its contents, optionally recursively.
        let pattern = if args.recursive {
            format!("{}/**/*", args.pattern)
        } else {
            format!("{}/*", args.pattern)
        };
        push_matches(&pattern, &mut files)?;
    } else {
        // Glob pattern mode.
        push_matches(&args.pattern, &mut files)?;
    }
    Ok(files)
}
/// Run OCR over `files` with at most `args.parallel` tasks in flight.
///
/// Per-file failures (after retries) are logged and dropped from the returned
/// vector rather than aborting the batch. Results come back in spawn order
/// because the join loop awaits handles sequentially.
async fn process_files_parallel(
    files: Vec<PathBuf>,
    args: &BatchArgs,
    config: &Arc<OcrConfig>,
    quiet: bool,
) -> Result<Vec<OcrResult>> {
    // Semaphore caps how many files are processed concurrently.
    let semaphore = Arc::new(Semaphore::new(args.parallel));
    let multi_progress = Arc::new(MultiProgress::new());
    // Overall bar exists only when not quiet; tasks rely on this invariant below.
    let overall_progress = if !quiet {
        let pb = multi_progress.add(ProgressBar::new(files.len() as u64));
        pb.set_style(
            ProgressStyle::default_bar()
                .template(
                    "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
                )
                .unwrap()
                .progress_chars("#>-"),
        );
        Some(pb)
    } else {
        None
    };
    let mut handles = Vec::new();
    for (_idx, file) in files.into_iter().enumerate() {
        // Cheap clones: Arc handles and (internally Arc-backed) progress bars.
        let semaphore = semaphore.clone();
        let config = config.clone();
        let multi_progress = multi_progress.clone();
        let overall_progress = overall_progress.clone();
        let max_retries = args.max_retries;
        let handle = tokio::spawn(async move {
            // Permit throttles concurrency; released when dropped at task end.
            let _permit = semaphore.acquire().await.unwrap();
            let file_progress = if !quiet {
                // unwrap is safe here: overall_progress is Some whenever !quiet.
                let pb = multi_progress.insert_before(
                    &overall_progress.as_ref().unwrap(),
                    ProgressBar::new_spinner(),
                );
                pb.set_style(
                    ProgressStyle::default_spinner()
                        .template("{spinner:.green} {msg}")
                        .unwrap(),
                );
                pb.set_message(format!("[{}] Processing...", file.display()));
                Some(pb)
            } else {
                None
            };
            let result = process_with_retry(&file, &config, max_retries).await;
            // Finish the per-file spinner with a success/failure line.
            if let Some(pb) = &file_progress {
                match &result {
                    Ok(r) => pb.finish_with_message(format!(
                        "[{}] ✓ Confidence: {:.2}%",
                        file.display(),
                        r.confidence * 100.0
                    )),
                    Err(e) => {
                        pb.finish_with_message(format!("[{}] ✗ Error: {}", file.display(), e))
                    }
                }
            }
            if let Some(pb) = &overall_progress {
                pb.inc(1);
            }
            result
        });
        handles.push(handle);
    }
    // Wait for all tasks to complete. Failed files and panicked tasks are
    // logged and excluded from the returned results.
    let mut results = Vec::new();
    for handle in handles {
        match handle.await {
            Ok(Ok(result)) => results.push(result),
            Ok(Err(e)) => error!("Processing failed: {}", e),
            Err(e) => error!("Task panicked: {}", e),
        }
    }
    if let Some(pb) = overall_progress {
        pb.finish_with_message("Batch processing complete");
    }
    Ok(results)
}
/// Process `file`, retrying up to `max_retries` extra times on failure.
///
/// Retries back off linearly (100 ms * attempt number); the last error is
/// returned once every attempt has failed.
async fn process_with_retry(
    file: &PathBuf,
    config: &OcrConfig,
    max_retries: usize,
) -> Result<OcrResult> {
    let mut failure = None;
    // One initial attempt plus up to `max_retries` retries.
    for attempt in 1..=(max_retries + 1) {
        match process_single_file(file, config).await {
            Ok(result) => return Ok(result),
            Err(e) => {
                failure = Some(e);
                if attempt <= max_retries {
                    debug!("Retry {}/{} for {}", attempt, max_retries, file.display());
                    tokio::time::sleep(tokio::time::Duration::from_millis(100 * attempt as u64))
                        .await;
                }
            }
        }
    }
    // The loop body ran at least once, so a failure was recorded.
    Err(failure.unwrap())
}
/// Process a single file through OCR.
///
/// NOTE(review): this is still a mock — it sleeps 50 ms and fabricates a
/// result with a random confidence in [0.7, 1.0); no real OCR runs yet, and
/// `_config` is unused until the TODO below is implemented.
async fn process_single_file(file: &PathBuf, _config: &OcrConfig) -> Result<OcrResult> {
    // TODO: Implement actual OCR processing
    // For now, return a mock result
    tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
    // Simulate varying confidence
    let confidence = 0.7 + (rand::random::<f64>() * 0.3);
    Ok(OcrResult {
        file: file.clone(),
        text: format!("OCR text from {}", file.display()),
        latex: Some(format!(r"\text{{Content from {}}}", file.display())),
        confidence,
        processing_time_ms: 50,
        errors: Vec::new(),
    })
}
fn save_results(
results: &[OcrResult],
output_dir: &PathBuf,
format: &OutputFormat,
separate_files: bool,
) -> Result<()> {
if separate_files {
// Save each result as a separate file
for (idx, result) in results.iter().enumerate() {
let filename = format!(
"result_{:04}.{}",
idx,
match format {
OutputFormat::Json => "json",
OutputFormat::Latex => "tex",
OutputFormat::Markdown => "md",
OutputFormat::MathMl => "xml",
OutputFormat::Text => "txt",
}
);
let output_path = output_dir.join(filename);
let content = format_single_result(result, format)?;
std::fs::write(&output_path, content)
.context(format!("Failed to write {}", output_path.display()))?;
}
} else {
// Save all results as a single file
let filename = format!(
"results.{}",
match format {
OutputFormat::Json => "json",
OutputFormat::Latex => "tex",
OutputFormat::Markdown => "md",
OutputFormat::MathMl => "xml",
OutputFormat::Text => "txt",
}
);
let output_path = output_dir.join(filename);
let content = format_batch_results(results, format)?;
std::fs::write(&output_path, content).context("Failed to write results file")?;
}
Ok(())
}
/// Render one OCR result in the requested output format.
///
/// LaTeX output falls back to the plain text when no LaTeX was produced.
fn format_single_result(result: &OcrResult, format: &OutputFormat) -> Result<String> {
    let rendered = match format {
        OutputFormat::Json => {
            return serde_json::to_string_pretty(result).context("Failed to serialize result");
        }
        OutputFormat::Text => result.text.clone(),
        OutputFormat::Latex => result.latex.clone().unwrap_or_else(|| result.text.clone()),
        OutputFormat::Markdown => format!("# {}\n\n{}\n", result.file.display(), result.text),
        OutputFormat::MathMl => format!(
            "<math xmlns=\"http://www.w3.org/1998/Math/MathML\">\n {}\n</math>",
            result.text
        ),
    };
    Ok(rendered)
}
/// Render all results as one document in the requested format.
///
/// JSON becomes a single array; every other format concatenates the individual
/// renderings, each followed by a horizontal-rule separator.
fn format_batch_results(results: &[OcrResult], format: &OutputFormat) -> Result<String> {
    if matches!(format, OutputFormat::Json) {
        return serde_json::to_string_pretty(results).context("Failed to serialize results");
    }
    let mut combined = String::new();
    for result in results {
        combined.push_str(&format_single_result(result, format)?);
        combined.push_str("\n\n---\n\n");
    }
    Ok(combined)
}
/// Load the OCR configuration from `config_path`, or fall back to defaults.
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
    match config_path {
        Some(path) => {
            let content = std::fs::read_to_string(path).context("Failed to read config file")?;
            toml::from_str(&content).context("Failed to parse config file")
        }
        None => Ok(OcrConfig::default()),
    }
}

View File

@@ -0,0 +1,272 @@
use anyhow::{Context, Result};
use clap::{Args, Subcommand};
use dialoguer::{theme::ColorfulTheme, Confirm, Input};
use std::path::PathBuf;
use tracing::info;
use super::OcrConfig;
use crate::cli::Cli;
/// Manage configuration
#[derive(Args, Debug, Clone)]
pub struct ConfigArgs {
    /// Which configuration action to perform.
    #[command(subcommand)]
    pub command: ConfigCommand,
}
/// Subcommands of `config`; see each handler in this module for behavior.
#[derive(Subcommand, Debug, Clone)]
pub enum ConfigCommand {
    /// Generate default configuration file
    Init {
        /// Output path for config file
        #[arg(short, long, default_value = "scipix.toml")]
        output: PathBuf,
        /// Overwrite existing file
        #[arg(short, long)]
        force: bool,
    },
    /// Validate configuration file
    Validate {
        /// Path to config file to validate
        #[arg(value_name = "FILE")]
        file: PathBuf,
    },
    /// Show current configuration
    Show {
        /// Path to config file (default: from --config or scipix.toml)
        #[arg(value_name = "FILE")]
        file: Option<PathBuf>,
    },
    /// Edit configuration interactively
    Edit {
        /// Path to config file to edit
        #[arg(value_name = "FILE")]
        file: PathBuf,
    },
    /// Get configuration directory path
    Path,
}
/// Dispatch a `config` subcommand.
///
/// `Show` falls back to the global `--config` path when no file argument is
/// supplied on the subcommand itself.
pub async fn execute(args: ConfigArgs, cli: &Cli) -> Result<()> {
    match args.command {
        ConfigCommand::Init { output, force } => {
            init_config(&output, force)?;
        }
        ConfigCommand::Validate { file } => {
            validate_config(&file)?;
        }
        ConfigCommand::Show { file } => {
            // or_else defers cloning cli.config until it is actually needed
            // (the previous `or(cli.config.clone())` cloned unconditionally).
            show_config(file.or_else(|| cli.config.clone()))?;
        }
        ConfigCommand::Edit { file } => {
            edit_config(&file)?;
        }
        ConfigCommand::Path => {
            show_config_path()?;
        }
    }
    Ok(())
}
/// Write a default configuration file to `output`.
///
/// Refuses to overwrite an existing file unless `force` is set, then prints
/// usage hints for the freshly created file.
fn init_config(output: &PathBuf, force: bool) -> Result<()> {
    if output.exists() && !force {
        anyhow::bail!(
            "Config file already exists: {} (use --force to overwrite)",
            output.display()
        );
    }
    let toml =
        toml::to_string_pretty(&OcrConfig::default()).context("Failed to serialize config")?;
    std::fs::write(output, toml).context("Failed to write config file")?;
    info!("Configuration file created: {}", output.display());
    println!("✓ Created configuration file: {}", output.display());
    println!("\nTo use this config, run:");
    println!(" scipix-cli --config {} <command>", output.display());
    println!("\nOr set environment variable:");
    println!(" export MATHPIX_CONFIG={}", output.display());
    Ok(())
}
/// Validate a configuration file and print its effective settings.
///
/// # Errors
///
/// Fails when the file is missing, unparsable, or contains out-of-range values.
fn validate_config(file: &PathBuf) -> Result<()> {
    if !file.exists() {
        anyhow::bail!("Config file not found: {}", file.display());
    }
    let content = std::fs::read_to_string(file).context("Failed to read config file")?;
    let config: OcrConfig = toml::from_str(&content).context("Failed to parse config file")?;
    // Validate configuration values.
    // The range check via `contains` also rejects NaN, which the previous
    // `< 0.0 || > 1.0` comparison silently accepted (NaN fails both).
    if !(0.0..=1.0).contains(&config.min_confidence) {
        anyhow::bail!("min_confidence must be between 0.0 and 1.0");
    }
    if config.max_image_size == 0 {
        anyhow::bail!("max_image_size must be greater than 0");
    }
    if config.supported_extensions.is_empty() {
        anyhow::bail!("supported_extensions cannot be empty");
    }
    println!("✓ Configuration is valid");
    println!("\nSettings:");
    println!(" Min confidence: {}", config.min_confidence);
    println!(" Max image size: {} bytes", config.max_image_size);
    println!(
        " Supported extensions: {}",
        config.supported_extensions.join(", ")
    );
    if let Some(endpoint) = &config.api_endpoint {
        println!(" API endpoint: {}", endpoint);
    }
    Ok(())
}
/// Print the contents of the configuration file, or a hint when none exists.
fn show_config(file: Option<PathBuf>) -> Result<()> {
    let config_path = match file {
        Some(path) => path,
        None => PathBuf::from("scipix.toml"),
    };
    if config_path.exists() {
        let content =
            std::fs::read_to_string(&config_path).context("Failed to read config file")?;
        println!("Configuration from: {}\n", config_path.display());
        println!("{}", content);
    } else {
        println!("No configuration file found.");
        println!("\nCreate one with:");
        println!(" scipix-cli config init");
    }
    Ok(())
}
/// Interactively edit an existing configuration file.
///
/// Prompts for each tunable value (pre-filled with the current setting) and
/// writes the file back only if the user confirms at the end.
fn edit_config(file: &PathBuf) -> Result<()> {
    if !file.exists() {
        anyhow::bail!(
            "Config file not found: {} (use 'config init' to create)",
            file.display()
        );
    }
    let content = std::fs::read_to_string(file).context("Failed to read config file")?;
    let mut config: OcrConfig = toml::from_str(&content).context("Failed to parse config file")?;
    let theme = ColorfulTheme::default();
    println!("Interactive Configuration Editor\n");
    // Edit min_confidence (input validated to [0.0, 1.0] before acceptance).
    config.min_confidence = Input::with_theme(&theme)
        .with_prompt("Minimum confidence threshold (0.0-1.0)")
        .default(config.min_confidence)
        .validate_with(|v: &f64| {
            if *v >= 0.0 && *v <= 1.0 {
                Ok(())
            } else {
                Err("Value must be between 0.0 and 1.0")
            }
        })
        .interact_text()
        .context("Failed to read input")?;
    // Edit max_image_size: prompted in MB, stored in bytes. Any sub-MB
    // remainder in the current value is truncated by this integer division.
    let max_size_mb = config.max_image_size / (1024 * 1024);
    let new_size_mb: usize = Input::with_theme(&theme)
        .with_prompt("Maximum image size (MB)")
        .default(max_size_mb)
        .interact_text()
        .context("Failed to read input")?;
    config.max_image_size = new_size_mb * 1024 * 1024;
    // Edit API endpoint: offer to change the existing one, or add a new one.
    if config.api_endpoint.is_some() {
        let edit_endpoint = Confirm::with_theme(&theme)
            .with_prompt("Edit API endpoint?")
            .default(false)
            .interact()
            .context("Failed to read input")?;
        if edit_endpoint {
            let endpoint: String = Input::with_theme(&theme)
                .with_prompt("API endpoint URL")
                .allow_empty(true)
                .interact_text()
                .context("Failed to read input")?;
            // Entering an empty string clears the endpoint.
            config.api_endpoint = if endpoint.is_empty() {
                None
            } else {
                Some(endpoint)
            };
        }
    } else {
        let add_endpoint = Confirm::with_theme(&theme)
            .with_prompt("Add API endpoint?")
            .default(false)
            .interact()
            .context("Failed to read input")?;
        if add_endpoint {
            let endpoint: String = Input::with_theme(&theme)
                .with_prompt("API endpoint URL")
                .interact_text()
                .context("Failed to read input")?;
            config.api_endpoint = Some(endpoint);
        }
    }
    // Save configuration only on explicit confirmation.
    let save = Confirm::with_theme(&theme)
        .with_prompt("Save changes?")
        .default(true)
        .interact()
        .context("Failed to read input")?;
    if save {
        let toml = toml::to_string_pretty(&config).context("Failed to serialize config")?;
        std::fs::write(file, toml).context("Failed to write config file")?;
        println!("\n✓ Configuration saved to: {}", file.display());
    } else {
        println!("\nChanges discarded.");
    }
    Ok(())
}
/// Print where SciPix looks for configuration and how to override it.
fn show_config_path() -> Result<()> {
    match dirs::config_dir() {
        Some(config_dir) => {
            let app_config = config_dir.join("scipix");
            println!("Default config directory: {}", app_config.display());
            if !app_config.exists() {
                println!("\nDirectory does not exist. Create it with:");
                println!(" mkdir -p {}", app_config.display());
            }
        }
        None => println!("Could not determine config directory"),
    }
    println!("\nYou can also use a custom config file:");
    println!(" scipix-cli --config /path/to/config.toml <command>");
    println!("\nOr set environment variable:");
    println!(" export MATHPIX_CONFIG=/path/to/config.toml");
    Ok(())
}

View File

@@ -0,0 +1,955 @@
//! Doctor command for environment analysis and configuration optimization
//!
//! Analyzes the system environment and provides recommendations for optimal
//! SciPix configuration based on available hardware and software capabilities.
use anyhow::Result;
use clap::Args;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Arguments for the doctor command
///
/// `--json` suppresses all human-readable output; `--check` narrows the run
/// to one diagnostic category (default: all).
#[derive(Args, Debug, Clone)]
pub struct DoctorArgs {
    /// Run in fix mode to automatically apply recommendations
    #[arg(long, help = "Automatically apply safe fixes")]
    pub fix: bool,
    /// Output detailed diagnostic information
    #[arg(long, short, help = "Show detailed diagnostic information")]
    pub verbose: bool,
    /// Output results as JSON
    #[arg(long, help = "Output results as JSON")]
    pub json: bool,
    /// Check only specific category (cpu, memory, config, deps, all)
    #[arg(long, default_value = "all", help = "Category to check")]
    pub check: CheckCategory,
    /// Path to configuration file to validate
    #[arg(long, help = "Path to configuration file to validate")]
    pub config_path: Option<PathBuf>,
}
/// Which family of diagnostic checks to run (CLI `--check` value).
#[derive(Debug, Clone, Copy, clap::ValueEnum, Default)]
pub enum CheckCategory {
    // Default: run every category below.
    #[default]
    All,
    Cpu,
    Memory,
    Config,
    Deps,
    Network,
}
/// Status of a diagnostic check
///
/// Pass/Warning/Fail grade a requirement; Info marks purely informational
/// entries (e.g. verbose-only details) that carry no judgement.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckStatus {
    Pass,
    Warning,
    Fail,
    Info,
}
impl std::fmt::Display for CheckStatus {
    // Render the status as a glyph for console output.
    //
    // NOTE(review): all four string literals below appear empty in this
    // source; the original status glyphs (check mark / warning / cross / info)
    // look like they were stripped by a lossy encoding round trip — confirm
    // against version control before relying on this output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CheckStatus::Pass => write!(f, ""),
            CheckStatus::Warning => write!(f, ""),
            CheckStatus::Fail => write!(f, ""),
            CheckStatus::Info => write!(f, ""),
        }
    }
}
/// A single diagnostic check result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticCheck {
    /// Short human-readable check name, e.g. "CPU Cores".
    pub name: String,
    /// Grouping label ("CPU", "Memory", "Dependencies", ...).
    pub category: String,
    /// Pass/Warning/Fail/Info grade for this check.
    pub status: CheckStatus,
    /// One-line result description shown next to the status.
    pub message: String,
    /// Optional remediation advice; aggregated into the report's recommendations.
    pub recommendation: Option<String>,
    /// Whether `--fix` can apply this recommendation automatically.
    pub auto_fixable: bool,
}
/// Complete diagnostic report
///
/// This is the JSON payload emitted by `doctor --json`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticReport {
    /// RFC 3339 timestamp of when the report was generated.
    pub timestamp: String,
    pub system_info: SystemInfo,
    pub checks: Vec<DiagnosticCheck>,
    /// All non-None recommendations collected from `checks`.
    pub recommendations: Vec<String>,
    pub optimal_config: OptimalConfig,
}
/// System information gathered during diagnosis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    pub os: String,
    pub arch: String,
    pub cpu_count: usize,
    pub cpu_brand: String,
    /// Total / available RAM in megabytes (fallback values on non-Linux).
    pub total_memory_mb: u64,
    pub available_memory_mb: u64,
    pub simd_features: SimdFeatures,
}
/// SIMD feature detection results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdFeatures {
    pub sse2: bool,
    pub sse4_1: bool,
    pub sse4_2: bool,
    pub avx: bool,
    pub avx2: bool,
    pub avx512f: bool,
    pub neon: bool,
    /// Human-readable name of the widest detected instruction set ("scalar" if none).
    pub best_available: String,
}
/// Optimal configuration recommendations
///
/// Derived from SystemInfo by `generate_optimal_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalConfig {
    pub batch_size: usize,
    pub worker_threads: usize,
    pub simd_backend: String,
    pub memory_limit_mb: u64,
    pub preprocessing_mode: String,
    pub cache_enabled: bool,
    pub cache_size_mb: u64,
}
/// Execute the doctor command
///
/// Gathers system info, runs the selected check category, prints (or, with
/// `--json`, dumps) the report, optionally applies fixes, and finishes with a
/// pass/warn/fail summary.
pub async fn execute(args: DoctorArgs) -> Result<()> {
    if !args.json {
        println!("🩺 SciPix Doctor - Environment Analysis\n");
        println!("═══════════════════════════════════════════════════════════\n");
    }
    let mut checks = Vec::new();
    // Gather system information
    let system_info = gather_system_info();
    // Run checks based on category
    match args.check {
        CheckCategory::All => {
            checks.extend(check_cpu(&system_info, args.verbose));
            checks.extend(check_memory(&system_info, args.verbose));
            checks.extend(check_dependencies(args.verbose));
            checks.extend(check_config(&args.config_path, args.verbose));
            checks.extend(check_network(args.verbose).await);
        }
        CheckCategory::Cpu => {
            checks.extend(check_cpu(&system_info, args.verbose));
        }
        CheckCategory::Memory => {
            checks.extend(check_memory(&system_info, args.verbose));
        }
        CheckCategory::Config => {
            checks.extend(check_config(&args.config_path, args.verbose));
        }
        CheckCategory::Deps => {
            checks.extend(check_dependencies(args.verbose));
        }
        CheckCategory::Network => {
            checks.extend(check_network(args.verbose).await);
        }
    }
    // Generate optimal configuration
    let optimal_config = generate_optimal_config(&system_info);
    // Collect recommendations (every check that supplied one)
    let recommendations: Vec<String> = checks
        .iter()
        .filter_map(|c| c.recommendation.clone())
        .collect();
    // Create report
    let report = DiagnosticReport {
        timestamp: chrono::Utc::now().to_rfc3339(),
        system_info: system_info.clone(),
        checks: checks.clone(),
        recommendations: recommendations.clone(),
        optimal_config: optimal_config.clone(),
    };
    // JSON mode: emit the machine-readable report and stop (no fixes applied).
    if args.json {
        println!("{}", serde_json::to_string_pretty(&report)?);
        return Ok(());
    }
    // Print system info
    print_system_info(&system_info);
    // Print check results
    print_check_results(&checks);
    // Print recommendations
    if !recommendations.is_empty() {
        println!("\n📋 Recommendations:");
        println!("───────────────────────────────────────────────────────────");
        for (i, rec) in recommendations.iter().enumerate() {
            println!(" {}. {}", i + 1, rec);
        }
    }
    // Print optimal configuration
    print_optimal_config(&optimal_config);
    // Apply fixes if requested
    if args.fix {
        apply_fixes(&checks).await?;
    }
    // Print summary
    print_summary(&checks);
    Ok(())
}
/// Collect OS, CPU, memory, and SIMD capability information for diagnosis.
fn gather_system_info() -> SystemInfo {
    let (total_memory_mb, available_memory_mb) = get_memory_info();
    SystemInfo {
        os: std::env::consts::OS.to_string(),
        arch: std::env::consts::ARCH.to_string(),
        cpu_count: num_cpus::get(),
        cpu_brand: get_cpu_brand(),
        total_memory_mb,
        available_memory_mb,
        simd_features: detect_simd_features(),
    }
}
/// Human-readable CPU brand string, with a generic per-architecture fallback.
fn get_cpu_brand() -> String {
    #[cfg(target_arch = "x86_64")]
    if let Some(brand) = get_x86_cpu_brand() {
        return brand;
    }
    // Fallback when no platform-specific brand string is available.
    format!("{} processor", std::env::consts::ARCH)
}
/// CPU brand string from /proc/cpuinfo (Linux x86_64 only).
///
/// Returns the value of the first usable "model name" line, or None when the
/// file is unreadable or no such line exists.
#[cfg(target_arch = "x86_64")]
fn get_x86_cpu_brand() -> Option<String> {
    let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok()?;
    cpuinfo
        .lines()
        .filter(|line| line.starts_with("model name"))
        .find_map(|line| line.split(':').nth(1))
        .map(|brand| brand.trim().to_string())
}
/// Stub for non-x86_64 targets: no /proc-based brand lookup is attempted.
#[cfg(not(target_arch = "x86_64"))]
fn get_x86_cpu_brand() -> Option<String> {
    None
}
fn get_memory_info() -> (u64, u64) {
// Try to read from /proc/meminfo on Linux
if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
let mut total = 0u64;
let mut available = 0u64;
for line in meminfo.lines() {
if line.starts_with("MemTotal:") {
if let Some(kb) = parse_meminfo_value(line) {
total = kb / 1024; // Convert to MB
}
} else if line.starts_with("MemAvailable:") {
if let Some(kb) = parse_meminfo_value(line) {
available = kb / 1024; // Convert to MB
}
}
}
if total > 0 {
return (total, available);
}
}
// Fallback values
(8192, 4096)
}
/// Second whitespace-separated token of a /proc/meminfo line, parsed as u64
/// (e.g. "MemTotal: 16384 kB" -> Some(16384)).
fn parse_meminfo_value(line: &str) -> Option<u64> {
    let mut fields = line.split_whitespace();
    fields.nth(1)?.parse().ok()
}
/// Probe the running CPU for SIMD instruction-set support.
///
/// On x86_64 each feature is tested at runtime and `best_available` is set to
/// the widest set found; on AArch64, NEON is assumed (baseline for that arch).
/// Everything else reports "scalar".
fn detect_simd_features() -> SimdFeatures {
    let mut features = SimdFeatures {
        sse2: false,
        sse4_1: false,
        sse4_2: false,
        avx: false,
        avx2: false,
        avx512f: false,
        neon: false,
        best_available: "scalar".to_string(),
    };
    #[cfg(target_arch = "x86_64")]
    {
        features.sse2 = is_x86_feature_detected!("sse2");
        features.sse4_1 = is_x86_feature_detected!("sse4.1");
        features.sse4_2 = is_x86_feature_detected!("sse4.2");
        features.avx = is_x86_feature_detected!("avx");
        features.avx2 = is_x86_feature_detected!("avx2");
        features.avx512f = is_x86_feature_detected!("avx512f");
        // Widest-first cascade; note sse4_1 is detected but never named best.
        if features.avx512f {
            features.best_available = "AVX-512".to_string();
        } else if features.avx2 {
            features.best_available = "AVX2".to_string();
        } else if features.avx {
            features.best_available = "AVX".to_string();
        } else if features.sse4_2 {
            features.best_available = "SSE4.2".to_string();
        } else if features.sse2 {
            features.best_available = "SSE2".to_string();
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        features.neon = true; // NEON is always available on AArch64
        features.best_available = "NEON".to_string();
    }
    features
}
/// CPU diagnostics: core count and SIMD capability, plus the brand string in
/// verbose mode.
///
/// NOTE(review): several string literals in the SIMD message below appear
/// empty in this source — they look like stripped yes/no glyphs; confirm
/// against version control.
fn check_cpu(system_info: &SystemInfo, verbose: bool) -> Vec<DiagnosticCheck> {
    let mut checks = Vec::new();
    // CPU count check: 8+ cores pass, 4-7 warn, fewer fail.
    let cpu_status = if system_info.cpu_count >= 8 {
        CheckStatus::Pass
    } else if system_info.cpu_count >= 4 {
        CheckStatus::Warning
    } else {
        CheckStatus::Fail
    };
    checks.push(DiagnosticCheck {
        name: "CPU Cores".to_string(),
        category: "CPU".to_string(),
        status: cpu_status,
        message: format!("{} cores detected", system_info.cpu_count),
        recommendation: if system_info.cpu_count < 4 {
            Some(
                "Consider running on a machine with more CPU cores for better batch processing"
                    .to_string(),
            )
        } else {
            None
        },
        auto_fixable: false,
    });
    // SIMD check: AVX2 or better passes; older vector sets warn; scalar fails.
    let simd_status = match system_info.simd_features.best_available.as_str() {
        "AVX-512" | "AVX2" => CheckStatus::Pass,
        "AVX" | "SSE4.2" | "NEON" => CheckStatus::Warning,
        _ => CheckStatus::Fail,
    };
    checks.push(DiagnosticCheck {
        name: "SIMD Support".to_string(),
        category: "CPU".to_string(),
        status: simd_status,
        message: format!(
            "Best SIMD: {} (SSE2: {}, AVX: {}, AVX2: {}, AVX-512: {})",
            system_info.simd_features.best_available,
            if system_info.simd_features.sse2 {
                ""
            } else {
                ""
            },
            if system_info.simd_features.avx {
                ""
            } else {
                ""
            },
            if system_info.simd_features.avx2 {
                ""
            } else {
                ""
            },
            if system_info.simd_features.avx512f {
                ""
            } else {
                ""
            },
        ),
        recommendation: if simd_status == CheckStatus::Fail {
            Some("Upgrade to a CPU with AVX2 support for 4x faster preprocessing".to_string())
        } else {
            None
        },
        auto_fixable: false,
    });
    // Verbose-only: report the CPU brand string as an informational entry.
    if verbose {
        checks.push(DiagnosticCheck {
            name: "CPU Brand".to_string(),
            category: "CPU".to_string(),
            status: CheckStatus::Info,
            message: system_info.cpu_brand.clone(),
            recommendation: None,
            auto_fixable: false,
        });
    }
    checks
}
/// Diagnostic checks for RAM: total capacity, currently-available share,
/// and (verbose only) memory per CPU core.
fn check_memory(system_info: &SystemInfo, verbose: bool) -> Vec<DiagnosticCheck> {
    let total = system_info.total_memory_mb;
    let available = system_info.available_memory_mb;

    // Grade total capacity: >=16 GB pass, >=8 GB warning, below that fail.
    let mem_status = match total {
        t if t >= 16384 => CheckStatus::Pass,
        t if t >= 8192 => CheckStatus::Warning,
        _ => CheckStatus::Fail,
    };
    let mut checks = vec![DiagnosticCheck {
        name: "Total Memory".to_string(),
        category: "Memory".to_string(),
        status: mem_status,
        message: format!("{} MB total", total),
        recommendation: (total < 8192).then(|| {
            "Consider upgrading to at least 8GB RAM for optimal batch processing".to_string()
        }),
        auto_fixable: false,
    }];

    // Grade how much of that memory is actually free right now.
    let avail_ratio = available as f64 / total as f64;
    let avail_status = match avail_ratio {
        r if r >= 0.5 => CheckStatus::Pass,
        r if r >= 0.25 => CheckStatus::Warning,
        _ => CheckStatus::Fail,
    };
    checks.push(DiagnosticCheck {
        name: "Available Memory".to_string(),
        category: "Memory".to_string(),
        status: avail_status,
        message: format!("{} MB available ({:.1}%)", available, avail_ratio * 100.0),
        recommendation: (avail_status == CheckStatus::Fail).then(|| {
            "Close some applications to free up memory before batch processing".to_string()
        }),
        auto_fixable: false,
    });

    // Informational only: per-core share of total RAM.
    if verbose {
        checks.push(DiagnosticCheck {
            name: "Memory per Core".to_string(),
            category: "Memory".to_string(),
            status: CheckStatus::Info,
            message: format!("{} MB/core", total / system_info.cpu_count as u64),
            recommendation: None,
            auto_fixable: false,
        });
    }
    checks
}
/// Diagnostic checks for external dependencies: ONNX Runtime, the image
/// stack, OpenSSL, and (verbose only) the installed Rust compiler.
fn check_dependencies(verbose: bool) -> Vec<DiagnosticCheck> {
    let mut checks = Vec::new();

    // ONNX Runtime is optional acceleration; absence is only a warning.
    let (onnx_found, onnx_message) = check_onnx_runtime();
    checks.push(DiagnosticCheck {
        name: "ONNX Runtime".to_string(),
        category: "Dependencies".to_string(),
        status: if onnx_found {
            CheckStatus::Pass
        } else {
            CheckStatus::Warning
        },
        message: onnx_message,
        recommendation: (!onnx_found).then(|| {
            "Install ONNX Runtime for neural network acceleration: https://onnxruntime.ai/"
                .to_string()
        }),
        auto_fixable: false,
    });

    // The image crate is compiled in, so this always passes.
    checks.push(DiagnosticCheck {
        name: "Image Processing".to_string(),
        category: "Dependencies".to_string(),
        status: CheckStatus::Pass,
        message: "image crate available (built-in)".to_string(),
        recommendation: None,
        auto_fixable: false,
    });

    // Probe the system `openssl` binary to gauge HTTPS support.
    let openssl_available = std::process::Command::new("openssl")
        .arg("version")
        .output()
        .map_or(false, |o| o.status.success());
    let (openssl_status, openssl_message) = if openssl_available {
        (CheckStatus::Pass, "OpenSSL available for HTTPS".to_string())
    } else {
        (CheckStatus::Warning, "OpenSSL not found".to_string())
    };
    checks.push(DiagnosticCheck {
        name: "OpenSSL".to_string(),
        category: "Dependencies".to_string(),
        status: openssl_status,
        message: openssl_message,
        recommendation: (!openssl_available)
            .then(|| "Install OpenSSL for secure API communication".to_string()),
        auto_fixable: false,
    });

    // Informational: report the rustc version when verbose.
    if verbose {
        if let Ok(output) = std::process::Command::new("rustc")
            .arg("--version")
            .output()
        {
            checks.push(DiagnosticCheck {
                name: "Rust Compiler".to_string(),
                category: "Dependencies".to_string(),
                status: CheckStatus::Info,
                message: String::from_utf8_lossy(&output.stdout).trim().to_string(),
                recommendation: None,
                auto_fixable: false,
            });
        }
    }
    checks
}
/// Probe for an ONNX Runtime shared library, either at a set of
/// well-known install locations or via the `ORT_DYLIB_PATH` variable.
///
/// Returns `(found, human-readable status message)`.
fn check_onnx_runtime() -> (bool, String) {
    // Conventional install locations, checked in order.
    const LIB_PATHS: [&str; 3] = [
        "/usr/lib/libonnxruntime.so",
        "/usr/local/lib/libonnxruntime.so",
        "/opt/onnxruntime/lib/libonnxruntime.so",
    ];
    if let Some(path) = LIB_PATHS.iter().find(|p| std::path::Path::new(p).exists()) {
        return (true, format!("Found at {}", path));
    }
    // Fall back to an explicit override via environment variable.
    if std::env::var("ORT_DYLIB_PATH").is_ok() {
        return (true, "Configured via ORT_DYLIB_PATH".to_string());
    }
    (
        false,
        "Not found (optional for ONNX acceleration)".to_string(),
    )
}
/// Diagnostic checks for configuration: locates a config file among the
/// usual candidates and reports which optional SCIPIX_* environment
/// variables are set.
fn check_config(config_path: &Option<PathBuf>, verbose: bool) -> Vec<DiagnosticCheck> {
    let mut checks = Vec::new();

    // Candidate locations, in priority order (explicit CLI path first).
    let config_locations = [
        config_path.clone(),
        Some(PathBuf::from("scipix.toml")),
        Some(PathBuf::from("config/scipix.toml")),
        dirs::config_dir().map(|p| p.join("scipix/config.toml")),
    ];

    match config_locations.iter().flatten().find(|loc| loc.exists()) {
        Some(loc) => {
            checks.push(DiagnosticCheck {
                name: "Configuration File".to_string(),
                category: "Config".to_string(),
                status: CheckStatus::Pass,
                message: format!("Found at {}", loc.display()),
                recommendation: None,
                auto_fixable: false,
            });
            // Shallow validation: only confirms expected section headers exist.
            if let Ok(content) = std::fs::read_to_string(loc) {
                if content.contains("[api]") || content.contains("[processing]") {
                    checks.push(DiagnosticCheck {
                        name: "Config Validity".to_string(),
                        category: "Config".to_string(),
                        status: CheckStatus::Pass,
                        message: "Configuration file is valid".to_string(),
                        recommendation: None,
                        auto_fixable: false,
                    });
                }
            }
        }
        None => {
            // Running on defaults is fine; offer to generate a config file.
            checks.push(DiagnosticCheck {
                name: "Configuration File".to_string(),
                category: "Config".to_string(),
                status: CheckStatus::Info,
                message: "No configuration file found (using defaults)".to_string(),
                recommendation: Some("Create a scipix.toml for custom settings".to_string()),
                auto_fixable: true,
            });
        }
    }

    // Optional environment variables: set ones are always reported,
    // unset ones only in verbose mode.
    for (var, desc) in &[
        ("SCIPIX_API_KEY", "API authentication"),
        ("SCIPIX_MODEL_PATH", "Custom model path"),
        ("SCIPIX_CACHE_DIR", "Cache directory"),
    ] {
        let is_set = std::env::var(var).is_ok();
        if verbose || is_set {
            checks.push(DiagnosticCheck {
                name: format!("Env: {}", var),
                category: "Config".to_string(),
                status: if is_set {
                    CheckStatus::Pass
                } else {
                    CheckStatus::Info
                },
                message: if is_set {
                    format!("{} configured", desc)
                } else {
                    format!("{} not set (optional)", desc)
                },
                recommendation: None,
                auto_fixable: false,
            });
        }
    }
    checks
}
/// Network diagnostics: verify loopback binding works and report
/// availability of the default service ports.
async fn check_network(verbose: bool) -> Vec<DiagnosticCheck> {
    let mut checks = Vec::new();

    // Can we bind an ephemeral port on loopback at all?
    let localhost_available = tokio::net::TcpListener::bind("127.0.0.1:0").await.is_ok();
    let (bind_status, bind_message) = if localhost_available {
        (CheckStatus::Pass, "Can bind to localhost".to_string())
    } else {
        (CheckStatus::Fail, "Cannot bind to localhost".to_string())
    };
    checks.push(DiagnosticCheck {
        name: "Localhost Binding".to_string(),
        category: "Network".to_string(),
        status: bind_status,
        message: bind_message,
        recommendation: (!localhost_available)
            .then(|| "Check firewall settings and port availability".to_string()),
        auto_fixable: false,
    });

    // Ports the server commonly uses; taken ports are always reported,
    // free ones only in verbose mode.
    for (port, desc) in &[(8080, "API server"), (3000, "Alternative API")] {
        let available = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
            .await
            .is_ok();
        if available && !verbose {
            continue;
        }
        checks.push(DiagnosticCheck {
            name: format!("Port {}", port),
            category: "Network".to_string(),
            status: if available {
                CheckStatus::Pass
            } else {
                CheckStatus::Warning
            },
            message: if available {
                format!("Port {} ({}) available", port, desc)
            } else {
                format!("Port {} ({}) in use", port, desc)
            },
            recommendation: (!available).then(|| {
                format!("Free port {} or use --port to specify alternative", port)
            }),
            auto_fixable: false,
        });
    }
    checks
}
/// Derive recommended runtime settings from the detected hardware.
fn generate_optimal_config(system_info: &SystemInfo) -> OptimalConfig {
    let avail = system_info.available_memory_mb;

    // Batch size scales with available memory (powers of two, minimum 4).
    let batch_size = match avail {
        m if m >= 8192 => 32,
        m if m >= 4096 => 16,
        m if m >= 2048 => 8,
        _ => 4,
    };

    // Use ~75% of the cores, but never fewer than two workers.
    let worker_threads = ((system_info.cpu_count as f64 * 0.75).ceil() as usize).max(2);

    // Cap memory use at 60% of what is currently free.
    let memory_limit_mb = (avail as f64 * 0.6) as u64;

    // Pick the preprocessing path matching the detected SIMD capabilities.
    let features = &system_info.simd_features;
    let preprocessing_mode = if features.avx2 || features.neon {
        "simd_optimized"
    } else if features.sse4_2 {
        "simd_basic"
    } else {
        "scalar"
    }
    .to_string();

    // Enable the cache only with headroom; size it at 10% of free RAM.
    let cache_enabled = avail >= 2048;
    let cache_size_mb = if cache_enabled {
        (avail as f64 * 0.1) as u64
    } else {
        0
    };

    OptimalConfig {
        batch_size,
        worker_threads,
        simd_backend: features.best_available.clone(),
        memory_limit_mb,
        preprocessing_mode,
        cache_enabled,
        cache_size_mb,
    }
}
/// Print a short human-readable summary of the detected hardware
/// (OS/arch, CPU, core count, RAM, best SIMD level) to stdout.
fn print_system_info(info: &SystemInfo) {
    println!("📊 System Information:");
    println!("───────────────────────────────────────────────────────────");
    println!(" OS: {} ({})", info.os, info.arch);
    println!(" CPU: {}", info.cpu_brand);
    println!(" Cores: {}", info.cpu_count);
    println!(
        " Memory: {} MB total, {} MB available",
        info.total_memory_mb, info.available_memory_mb
    );
    println!(" Best SIMD: {}", info.simd_features.best_available);
    println!();
}
/// Print every diagnostic check to stdout, grouped under category
/// headers, with an ANSI-colored status marker per line.
fn print_check_results(checks: &[DiagnosticCheck]) {
    println!("🔍 Diagnostic Checks:");
    println!("───────────────────────────────────────────────────────────");
    let mut current_category = String::new();
    for check in checks {
        // Start a new section whenever the category changes (checks are
        // assumed to arrive grouped by category).
        if check.category != current_category {
            if !current_category.is_empty() {
                println!();
            }
            println!(" [{}]", check.category);
            current_category.clone_from(&check.category);
        }
        // Map each status to its terminal color escape.
        let status_color = match check.status {
            CheckStatus::Pass => "\x1b[32m",    // Green
            CheckStatus::Warning => "\x1b[33m", // Yellow
            CheckStatus::Fail => "\x1b[31m",    // Red
            CheckStatus::Info => "\x1b[36m",    // Cyan
        };
        println!(
            " {}{}\x1b[0m {} - {}",
            status_color, check.status, check.name, check.message
        );
    }
    println!();
}
/// Print the recommended configuration, followed by a copy-pastable
/// `scipix.toml` snippet built from the same values.
fn print_optimal_config(config: &OptimalConfig) {
    println!("\n⚙️ Optimal Configuration:");
    println!("───────────────────────────────────────────────────────────");
    println!(" batch_size: {}", config.batch_size);
    println!(" worker_threads: {}", config.worker_threads);
    println!(" simd_backend: {}", config.simd_backend);
    println!(" memory_limit: {} MB", config.memory_limit_mb);
    println!(" preprocessing: {}", config.preprocessing_mode);
    println!(" cache_enabled: {}", config.cache_enabled);
    // Cache size only makes sense to show when caching is on.
    if config.cache_enabled {
        println!(" cache_size: {} MB", config.cache_size_mb);
    }
    println!("\n 📝 Example configuration (scipix.toml):");
    println!(" ─────────────────────────────────────────");
    println!(" [processing]");
    println!(" batch_size = {}", config.batch_size);
    println!(" worker_threads = {}", config.worker_threads);
    println!(" simd_backend = \"{}\"", config.simd_backend);
    println!(" memory_limit_mb = {}", config.memory_limit_mb);
    println!();
    println!(" [cache]");
    println!(" enabled = {}", config.cache_enabled);
    println!(" size_mb = {}", config.cache_size_mb);
}
/// Print pass/warning/fail totals and an overall verdict line.
fn print_summary(checks: &[DiagnosticCheck]) {
    // Tally in a single pass instead of filtering three times.
    let (mut pass_count, mut warn_count, mut fail_count) = (0usize, 0usize, 0usize);
    for check in checks {
        match check.status {
            CheckStatus::Pass => pass_count += 1,
            CheckStatus::Warning => warn_count += 1,
            CheckStatus::Fail => fail_count += 1,
            // Informational entries do not affect the verdict.
            _ => {}
        }
    }
    println!("\n═══════════════════════════════════════════════════════════");
    println!(
        "📋 Summary: {} passed, {} warnings, {} failed",
        pass_count, warn_count, fail_count
    );
    // Verdict severity: any failure outranks warnings; warnings outrank clean.
    if fail_count > 0 {
        println!("\n⚠️ Some checks failed. Review recommendations above.");
    } else if warn_count > 0 {
        println!("\n✓ System is functional with some areas for improvement.");
    } else {
        println!("\n✅ System is optimally configured for SciPix!");
    }
}
/// Attempt automatic remediation for checks flagged `auto_fixable`.
///
/// Currently the only supported fix is generating a default
/// `config/scipix.toml` when no configuration file was found.
async fn apply_fixes(checks: &[DiagnosticCheck]) -> Result<()> {
    println!("\n🔧 Applying automatic fixes...");
    println!("───────────────────────────────────────────────────────────");
    // Only checks explicitly marked auto-fixable are acted upon.
    let fixable: Vec<_> = checks.iter().filter(|c| c.auto_fixable).collect();
    if fixable.is_empty() {
        println!(" No automatic fixes available.");
        return Ok(());
    }
    for check in fixable {
        println!(" Fixing: {}", check.name);
        if check.name == "Configuration File" {
            // Create default config file
            // (raw string content is written to disk verbatim)
            let config_content = r#"# SciPix Configuration
# Generated by scipix doctor --fix
[processing]
batch_size = 16
worker_threads = 4
simd_backend = "auto"
memory_limit_mb = 4096
[cache]
enabled = true
size_mb = 256
[api]
host = "127.0.0.1"
port = 8080
timeout_seconds = 30
[logging]
level = "info"
format = "pretty"
"#;
            // Create config directory if needed
            let config_path = PathBuf::from("config");
            if !config_path.exists() {
                std::fs::create_dir_all(&config_path)?;
            }
            let config_file = config_path.join("scipix.toml");
            std::fs::write(&config_file, config_content)?;
            println!(" ✓ Created {}", config_file.display());
        }
    }
    Ok(())
}

View File

@@ -0,0 +1,806 @@
//! MCP (Model Context Protocol) Server Implementation for SciPix
//!
//! Implements the MCP specification (protocol revision 2024-11-05) for exposing OCR capabilities
//! as tools that can be discovered and invoked by AI hosts.
//!
//! ## Usage
//! ```bash
//! scipix-cli mcp
//! ```
//!
//! ## Protocol
//! Uses JSON-RPC 2.0 over STDIO for communication.
use clap::Args;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use std::io::{self, BufRead, Write};
use std::path::PathBuf;
/// MCP Server Arguments
// Plain `//` comments are used below so clap's derived help text
// (driven by `///` doc comments and `help = ...`) stays unchanged.
#[derive(Args, Debug, Clone)]
pub struct McpArgs {
    /// Enable debug logging for MCP messages
    // Diagnostics go to stderr; stdout is reserved for JSON-RPC frames.
    #[arg(long, help = "Enable debug logging")]
    pub debug: bool,
    /// Custom model path for OCR
    // NOTE(review): stored on McpServer but not read by the placeholder
    // OCR pipeline — confirm intended use before relying on it.
    #[arg(long, help = "Path to ONNX models directory")]
    pub models_dir: Option<PathBuf>,
}
/// JSON-RPC 2.0 Request
///
/// Incoming message frame; the meaning of `params` depends on `method`.
#[derive(Debug, Deserialize)]
struct JsonRpcRequest {
    /// Protocol marker; expected to be "2.0" but never validated or read.
    #[allow(dead_code)]
    jsonrpc: String,
    /// Request id; absent for JSON-RPC notifications.
    id: Option<Value>,
    /// Method name, e.g. "initialize" or "tools/call".
    method: String,
    /// Method-specific parameters, if any.
    params: Option<Value>,
}
/// JSON-RPC 2.0 Response
///
/// Exactly one of `result` / `error` is populated (enforced by the
/// `success` / `error` constructors); the empty side is omitted from
/// the serialized JSON.
#[derive(Debug, Serialize)]
struct JsonRpcResponse {
    /// Always serialized as "2.0".
    jsonrpc: String,
    /// Mirrors the request id (Null for parse errors and notifications).
    id: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
}
/// JSON-RPC 2.0 Error
#[derive(Debug, Serialize)]
struct JsonRpcError {
    /// Standard JSON-RPC error code (e.g. -32601 method-not-found,
    /// -32700 parse error).
    code: i32,
    /// Human-readable description of the failure.
    message: String,
    /// Optional structured detail; the `error` constructor always
    /// leaves this None.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<Value>,
}
/// MCP Server Info
///
/// Identifies this server during the `initialize` handshake.
#[derive(Debug, Serialize)]
struct ServerInfo {
    /// Server name ("scipix-mcp").
    name: String,
    /// Crate version, taken from CARGO_PKG_VERSION at build time.
    version: String,
}
/// MCP Server Capabilities
///
/// Advertised during `initialize`; resources are currently never
/// exposed (`capabilities()` always sets `resources: None`), so only
/// the tools capability is serialized.
#[derive(Debug, Serialize)]
struct ServerCapabilities {
    /// Tool discovery/invocation support.
    tools: ToolsCapability,
    #[serde(skip_serializing_if = "Option::is_none")]
    resources: Option<ResourcesCapability>,
}
/// Tools capability flags for the `initialize` response.
#[derive(Debug, Serialize)]
struct ToolsCapability {
    /// Whether the server emits tool-list-changed notifications; always
    /// false here since the tool set is static.
    #[serde(rename = "listChanged")]
    list_changed: bool,
}
/// Resources capability flags.
///
/// NOTE(review): defined for completeness but never instantiated —
/// `capabilities()` always reports `resources: None`.
#[derive(Debug, Serialize)]
struct ResourcesCapability {
    /// Whether clients may subscribe to resource updates.
    subscribe: bool,
    /// Whether resource-list-changed notifications are emitted.
    #[serde(rename = "listChanged")]
    list_changed: bool,
}
/// MCP Tool Definition
///
/// Serialized into the `tools/list` response.
#[derive(Debug, Serialize)]
struct Tool {
    /// Unique tool identifier, e.g. "ocr_image".
    name: String,
    /// Model-facing usage description (when to use, examples, returns).
    description: String,
    /// JSON Schema object describing the tool's arguments.
    #[serde(rename = "inputSchema")]
    input_schema: Value,
}
/// Tool call result
///
/// Typed mirror of the MCP tool-result payload. Currently unused:
/// `handle_tools_call` builds the equivalent structure with `json!`
/// directly, hence the dead_code allowance.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct ToolResult {
    /// Ordered content blocks returned by the tool.
    content: Vec<ContentBlock>,
    /// Some(true) when the content describes a tool failure.
    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
    is_error: Option<bool>,
}
/// A single piece of tool-result content (text only in this server).
/// Unused for the same reason as `ToolResult` — responses are built
/// with `json!` inline.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct ContentBlock {
    /// MCP content discriminator, e.g. "text".
    #[serde(rename = "type")]
    content_type: String,
    /// The textual payload.
    text: String,
}
impl JsonRpcResponse {
fn success(id: Value, result: Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: Some(result),
error: None,
}
}
fn error(id: Value, code: i32, message: &str) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id,
result: None,
error: Some(JsonRpcError {
code,
message: message.to_string(),
data: None,
}),
}
}
}
/// MCP Server state
struct McpServer {
    /// Echo protocol traffic and diagnostics to stderr when true.
    debug: bool,
    /// Optional ONNX models directory from the CLI.
    // NOTE(review): currently stored but unused by the placeholder OCR
    // path — confirm before documenting as functional.
    #[allow(dead_code)]
    models_dir: Option<PathBuf>,
}
impl McpServer {
    /// Build a server instance from the parsed CLI arguments.
    fn new(args: &McpArgs) -> Self {
        Self {
            debug: args.debug,
            models_dir: args.models_dir.clone(),
        }
    }

    /// Get server info for initialization
    fn server_info(&self) -> ServerInfo {
        ServerInfo {
            name: "scipix-mcp".to_string(),
            version: env!("CARGO_PKG_VERSION").to_string(),
        }
    }

    /// Get server capabilities (tools only; no resources are exposed).
    fn capabilities(&self) -> ServerCapabilities {
        ServerCapabilities {
            tools: ToolsCapability {
                list_changed: false,
            },
            resources: None,
        }
    }

    /// Define available tools with examples following Anthropic best practices
    /// See: https://www.anthropic.com/engineering/advanced-tool-use
    fn get_tools(&self) -> Vec<Tool> {
        vec![
            Tool {
                name: "ocr_image".to_string(),
                description: r#"Process an image file with OCR to extract text and mathematical formulas.
WHEN TO USE: Use this tool when you have an image file path containing text, equations,
or mathematical notation that needs to be converted to a machine-readable format.
EXAMPLES:
- Extract LaTeX from a photo of a math equation: {"image_path": "equation.png", "format": "latex"}
- Get plain text from a document scan: {"image_path": "document.jpg", "format": "text"}
- Convert handwritten math to AsciiMath: {"image_path": "notes.png", "format": "asciimath"}
RETURNS: JSON with the recognized content, confidence score (0-1), and processing metadata."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_path": {
                            "type": "string",
                            "description": "Absolute or relative path to image file (PNG, JPG, JPEG, GIF, BMP, TIFF supported)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "mathml", "asciimath"],
                            "default": "latex",
                            "description": "Output format: 'latex' for mathematical notation, 'text' for plain text, 'mathml' for XML, 'asciimath' for simple notation"
                        }
                    },
                    "required": ["image_path"],
                    "examples": [
                        {"image_path": "/path/to/equation.png", "format": "latex"},
                        {"image_path": "document.jpg", "format": "text"}
                    ]
                }),
            },
            Tool {
                name: "ocr_base64".to_string(),
                description: r#"Process a base64-encoded image with OCR. Use when image data is inline rather than a file.
WHEN TO USE: Use this tool when you have image data as a base64 string (e.g., from an API
response, clipboard, or embedded in a document) rather than a file path.
EXAMPLES:
- Process clipboard image: {"image_data": "iVBORw0KGgo...", "format": "latex"}
- Extract text from API response image: {"image_data": "<base64_string>", "format": "text"}
NOTE: The base64 string should not include the data URI prefix (e.g., "data:image/png;base64,")."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_data": {
                            "type": "string",
                            "description": "Base64-encoded image data (without data URI prefix)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "mathml", "asciimath"],
                            "default": "latex",
                            "description": "Output format for recognized content"
                        }
                    },
                    "required": ["image_data"]
                }),
            },
            Tool {
                name: "batch_ocr".to_string(),
                description: r#"Process multiple images in a directory with OCR. Efficient for bulk operations.
WHEN TO USE: Use this tool when you need to process 3+ images in the same directory.
For 1-2 images, use ocr_image instead for simpler results.
EXAMPLES:
- Process all PNGs in a folder: {"directory": "./images", "pattern": "*.png"}
- Process specific equation images: {"directory": "/docs/math", "pattern": "eq_*.jpg"}
- Get JSON results for all images: {"directory": ".", "pattern": "*.{png,jpg}", "format": "json"}
RETURNS: Array of results with file paths, recognized content, and confidence scores."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "directory": {
                            "type": "string",
                            "description": "Directory path containing images to process"
                        },
                        "pattern": {
                            "type": "string",
                            "default": "*.png",
                            "description": "Glob pattern to match files (e.g., '*.png', '*.{jpg,png}', 'equation_*.jpg')"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "json"],
                            "default": "json",
                            "description": "Output format: 'json' for structured results (recommended), 'latex' or 'text' for concatenated output"
                        }
                    },
                    "required": ["directory"]
                }),
            },
            Tool {
                name: "preprocess_image".to_string(),
                description: r#"Apply preprocessing operations to optimize an image for OCR.
WHEN TO USE: Use this tool BEFORE ocr_image when dealing with:
- Low contrast images (use threshold)
- Large images that need resizing (use resize)
- Color images (use grayscale for faster processing)
- Noisy or blurry images (use denoise)
EXAMPLES:
- Prepare scan for OCR: {"image_path": "scan.jpg", "output_path": "scan_clean.png", "operations": ["grayscale", "threshold"]}
- Resize large image: {"image_path": "photo.jpg", "output_path": "photo_small.png", "operations": ["resize"], "target_width": 800}
WORKFLOW: preprocess_image -> ocr_image for best results on problematic images."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_path": {
                            "type": "string",
                            "description": "Path to input image file"
                        },
                        "output_path": {
                            "type": "string",
                            "description": "Path for preprocessed output image"
                        },
                        "operations": {
                            "type": "array",
                            "items": {
                                "type": "string",
                                "enum": ["grayscale", "resize", "threshold", "denoise", "deskew"]
                            },
                            "default": ["grayscale", "resize"],
                            "description": "Operations to apply in order: grayscale (convert to B&W), resize (scale to target size), threshold (binarize), denoise (reduce noise), deskew (straighten)"
                        },
                        "target_width": {
                            "type": "integer",
                            "default": 640,
                            "description": "Target width for resize (preserves aspect ratio)"
                        },
                        "target_height": {
                            "type": "integer",
                            "default": 480,
                            "description": "Target height for resize (preserves aspect ratio)"
                        }
                    },
                    "required": ["image_path", "output_path"]
                }),
            },
            Tool {
                name: "latex_to_mathml".to_string(),
                description: r#"Convert LaTeX mathematical notation to MathML XML format.
WHEN TO USE: Use this tool when you need MathML output from LaTeX, such as:
- Generating accessible math content for web pages
- Converting equations for screen readers
- Integrating with systems that require MathML
EXAMPLES:
- Convert fraction: {"latex": "\\frac{1}{2}"}
- Convert integral: {"latex": "\\int_0^1 x^2 dx"}
- Convert matrix: {"latex": "\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}"}"#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "latex": {
                            "type": "string",
                            "description": "LaTeX expression to convert (with or without $ delimiters)"
                        }
                    },
                    "required": ["latex"],
                    "examples": [
                        {"latex": "\\frac{a}{b}"},
                        {"latex": "E = mc^2"}
                    ]
                }),
            },
            Tool {
                name: "benchmark_performance".to_string(),
                description: r#"Run performance benchmarks on the OCR pipeline and return timing metrics.
WHEN TO USE: Use this tool to:
- Verify OCR performance on your system
- Compare preprocessing options
- Debug slow processing issues
EXAMPLES:
- Quick performance check: {"iterations": 5}
- Test specific image: {"image_path": "test.png", "iterations": 10}
RETURNS: Average processing times for grayscale, resize operations, and system info."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "iterations": {
                            "type": "integer",
                            "default": 10,
                            "minimum": 1,
                            "maximum": 100,
                            "description": "Number of benchmark iterations (higher = more accurate, slower)"
                        },
                        "image_path": {
                            "type": "string",
                            "description": "Optional: Path to test image (uses generated test image if not provided)"
                        }
                    }
                }),
            },
        ]
    }

    /// Handle incoming JSON-RPC request
    ///
    /// Dispatches on the method name; unknown methods yield -32601.
    async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
        // Notifications carry no id; fall back to Null so a response can
        // still be serialized.
        let id = request.id.unwrap_or(Value::Null);
        if self.debug {
            eprintln!("[MCP DEBUG] Method: {}", request.method);
            if let Some(ref params) = request.params {
                eprintln!(
                    "[MCP DEBUG] Params: {}",
                    serde_json::to_string_pretty(params).unwrap_or_default()
                );
            }
        }
        match request.method.as_str() {
            "initialize" => self.handle_initialize(id, request.params),
            // NOTE(review): "initialized" is a notification in MCP; replying
            // with an empty result is harmless but not strictly required.
            "initialized" => JsonRpcResponse::success(id, json!({})),
            "tools/list" => self.handle_tools_list(id),
            "tools/call" => self.handle_tools_call(id, request.params).await,
            "ping" => JsonRpcResponse::success(id, json!({})),
            // The host asked us to stop; terminate without replying.
            "shutdown" => {
                std::process::exit(0);
            }
            _ => {
                JsonRpcResponse::error(id, -32601, &format!("Method not found: {}", request.method))
            }
        }
    }

    /// Handle initialize request
    ///
    /// Echoes the supported protocol revision plus our identity and
    /// capability set.
    fn handle_initialize(&self, id: Value, params: Option<Value>) -> JsonRpcResponse {
        if self.debug {
            if let Some(p) = &params {
                eprintln!(
                    "[MCP DEBUG] Client info: {}",
                    serde_json::to_string_pretty(p).unwrap_or_default()
                );
            }
        }
        JsonRpcResponse::success(
            id,
            json!({
                "protocolVersion": "2024-11-05",
                "serverInfo": self.server_info(),
                "capabilities": self.capabilities()
            }),
        )
    }

    /// Handle tools/list request
    fn handle_tools_list(&self, id: Value) -> JsonRpcResponse {
        JsonRpcResponse::success(
            id,
            json!({
                "tools": self.get_tools()
            }),
        )
    }

    /// Handle tools/call request
    ///
    /// Tool failures are reported as a *successful* JSON-RPC response
    /// with `isError: true` content, per the MCP tool-result convention;
    /// JSON-RPC errors are reserved for protocol-level problems.
    async fn handle_tools_call(&self, id: Value, params: Option<Value>) -> JsonRpcResponse {
        let params = match params {
            Some(p) => p,
            None => return JsonRpcResponse::error(id, -32602, "Missing params"),
        };
        let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
        let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
        if self.debug {
            eprintln!(
                "[MCP DEBUG] Tool call: {} with args: {}",
                tool_name, arguments
            );
        }
        let result = match tool_name {
            "ocr_image" => self.call_ocr_image(&arguments).await,
            "ocr_base64" => self.call_ocr_base64(&arguments).await,
            "batch_ocr" => self.call_batch_ocr(&arguments).await,
            "preprocess_image" => self.call_preprocess_image(&arguments).await,
            "latex_to_mathml" => self.call_latex_to_mathml(&arguments).await,
            "benchmark_performance" => self.call_benchmark(&arguments).await,
            _ => Err(format!("Unknown tool: {}", tool_name)),
        };
        match result {
            Ok(content) => JsonRpcResponse::success(
                id,
                json!({
                    "content": [{
                        "type": "text",
                        "text": content
                    }]
                }),
            ),
            Err(e) => JsonRpcResponse::success(
                id,
                json!({
                    "content": [{
                        "type": "text",
                        "text": e
                    }],
                    "isError": true
                }),
            ),
        }
    }

    /// OCR image file
    async fn call_ocr_image(&self, args: &Value) -> Result<String, String> {
        let image_path = args
            .get("image_path")
            .and_then(|p| p.as_str())
            .ok_or("Missing image_path parameter")?;
        let format = args
            .get("format")
            .and_then(|f| f.as_str())
            .unwrap_or("latex");
        // Check if file exists
        if !std::path::Path::new(image_path).exists() {
            return Err(format!("Image file not found: {}", image_path));
        }
        // Load and process image
        let img = image::open(image_path).map_err(|e| format!("Failed to load image: {}", e))?;
        // Perform OCR (using mock for now, real inference when models are available)
        let result = self.perform_ocr(&img, format).await?;
        // Confidence is a fixed placeholder until real inference lands.
        Ok(serde_json::to_string_pretty(&json!({
            "file": image_path,
            "format": format,
            "result": result,
            "confidence": 0.95
        }))
        .unwrap_or_default())
    }

    /// OCR base64 image
    async fn call_ocr_base64(&self, args: &Value) -> Result<String, String> {
        let image_data = args
            .get("image_data")
            .and_then(|d| d.as_str())
            .ok_or("Missing image_data parameter")?;
        let format = args
            .get("format")
            .and_then(|f| f.as_str())
            .unwrap_or("latex");
        // Decode base64 (standard alphabet, no data-URI prefix expected).
        let decoded =
            base64::Engine::decode(&base64::engine::general_purpose::STANDARD, image_data)
                .map_err(|e| format!("Invalid base64 data: {}", e))?;
        // Load image from bytes
        let img = image::load_from_memory(&decoded)
            .map_err(|e| format!("Failed to load image from data: {}", e))?;
        // Perform OCR
        let result = self.perform_ocr(&img, format).await?;
        Ok(serde_json::to_string_pretty(&json!({
            "format": format,
            "result": result,
            "confidence": 0.95
        }))
        .unwrap_or_default())
    }

    /// Batch OCR processing
    ///
    /// Globs `directory`/`pattern` and runs OCR on each match. Per-file
    /// failures (load or recognition) are reported with an "error" field
    /// instead of aborting the whole batch.
    async fn call_batch_ocr(&self, args: &Value) -> Result<String, String> {
        let directory = args
            .get("directory")
            .and_then(|d| d.as_str())
            .ok_or("Missing directory parameter")?;
        let pattern = args
            .get("pattern")
            .and_then(|p| p.as_str())
            .unwrap_or("*.png");
        let format = args
            .get("format")
            .and_then(|f| f.as_str())
            .unwrap_or("json");
        // Fix: `format` selects the response envelope ("json"/"latex"/"text"),
        // but perform_ocr only understands content formats. Previously the
        // default "json" was forwarded verbatim and every file came back as
        // "Unknown format"; map it to LaTeX content instead.
        let ocr_format = if format == "json" { "latex" } else { format };
        // Find files matching pattern
        let glob_pattern = format!("{}/{}", directory, pattern);
        let paths: Vec<_> = glob::glob(&glob_pattern)
            .map_err(|e| format!("Invalid glob pattern: {}", e))?
            .filter_map(|p| p.ok())
            .collect();
        let mut results = Vec::new();
        for path in &paths {
            let img = match image::open(path) {
                Ok(img) => img,
                Err(e) => {
                    results.push(json!({
                        "file": path.display().to_string(),
                        "error": e.to_string()
                    }));
                    continue;
                }
            };
            // Fix: keep recognition failures distinct from results instead of
            // the previous unwrap_or_else(|e| e), which silently used the
            // error text as the "result".
            match self.perform_ocr(&img, ocr_format).await {
                Ok(ocr_result) => results.push(json!({
                    "file": path.display().to_string(),
                    "result": ocr_result,
                    "confidence": 0.95
                })),
                Err(e) => results.push(json!({
                    "file": path.display().to_string(),
                    "error": e
                })),
            }
        }
        Ok(serde_json::to_string_pretty(&json!({
            "total": paths.len(),
            "processed": results.len(),
            "results": results
        }))
        .unwrap_or_default())
    }

    /// Preprocess image
    ///
    /// NOTE: only "grayscale" and "resize" are implemented; the other
    /// advertised operations (threshold/denoise/deskew) are accepted by
    /// the schema but currently fall through as no-ops.
    async fn call_preprocess_image(&self, args: &Value) -> Result<String, String> {
        let image_path = args
            .get("image_path")
            .and_then(|p| p.as_str())
            .ok_or("Missing image_path parameter")?;
        let output_path = args
            .get("output_path")
            .and_then(|p| p.as_str())
            .ok_or("Missing output_path parameter")?;
        let operations: Vec<&str> = args
            .get("operations")
            .and_then(|o| o.as_array())
            .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
            .unwrap_or_else(|| vec!["grayscale", "resize"]);
        // Load image
        let mut img =
            image::open(image_path).map_err(|e| format!("Failed to load image: {}", e))?;
        // Apply operations in the order requested.
        for op in &operations {
            match *op {
                "grayscale" => {
                    img = image::DynamicImage::ImageLuma8(img.to_luma8());
                }
                "resize" => {
                    let width = args
                        .get("target_width")
                        .and_then(|w| w.as_u64())
                        .unwrap_or(640) as u32;
                    let height = args
                        .get("target_height")
                        .and_then(|h| h.as_u64())
                        .unwrap_or(480) as u32;
                    img = img.resize(width, height, image::imageops::FilterType::Lanczos3);
                }
                // threshold / denoise / deskew: not yet implemented.
                _ => {}
            }
        }
        // Save output
        img.save(output_path)
            .map_err(|e| format!("Failed to save image: {}", e))?;
        Ok(serde_json::to_string_pretty(&json!({
            "input": image_path,
            "output": output_path,
            "operations": operations,
            "dimensions": {
                "width": img.width(),
                "height": img.height()
            }
        }))
        .unwrap_or_default())
    }

    /// Convert LaTeX to MathML
    async fn call_latex_to_mathml(&self, args: &Value) -> Result<String, String> {
        let latex = args
            .get("latex")
            .and_then(|l| l.as_str())
            .ok_or("Missing latex parameter")?;
        // Simple LaTeX to MathML conversion (placeholder): strips LaTeX
        // syntax characters and wraps the remainder in a bare <mi> node.
        let mathml = format!(
            r#"<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>{}</mi></mrow></math>"#,
            latex.replace("\\", "").replace("{", "").replace("}", "")
        );
        Ok(serde_json::to_string_pretty(&json!({
            "latex": latex,
            "mathml": mathml
        }))
        .unwrap_or_default())
    }

    /// Run performance benchmark
    ///
    /// Times grayscale conversion and resize on a generated blank image;
    /// the "image_path" argument advertised by the schema is currently
    /// not consulted.
    async fn call_benchmark(&self, args: &Value) -> Result<String, String> {
        let iterations = args
            .get("iterations")
            .and_then(|i| i.as_u64())
            .unwrap_or(10) as usize;
        use std::time::Instant;
        // Generate a 400x100 all-white test image.
        let test_img =
            image::DynamicImage::ImageRgb8(image::ImageBuffer::from_fn(400, 100, |_, _| {
                image::Rgb([255u8, 255u8, 255u8])
            }));
        // Benchmark preprocessing
        let start = Instant::now();
        for _ in 0..iterations {
            let _gray = test_img.to_luma8();
        }
        let grayscale_time = start.elapsed() / iterations as u32;
        let start = Instant::now();
        for _ in 0..iterations {
            let _resized = test_img.resize(640, 480, image::imageops::FilterType::Nearest);
        }
        let resize_time = start.elapsed() / iterations as u32;
        Ok(serde_json::to_string_pretty(&json!({
            "iterations": iterations,
            "benchmarks": {
                "grayscale_avg_ms": grayscale_time.as_secs_f64() * 1000.0,
                "resize_avg_ms": resize_time.as_secs_f64() * 1000.0,
            },
            "system": {
                "cpu_cores": num_cpus::get()
            }
        }))
        .unwrap_or_default())
    }

    /// Perform OCR on image (placeholder implementation)
    ///
    /// Returns a canned sample string per format; the image is ignored.
    /// In production this would invoke the actual OCR engine.
    async fn perform_ocr(
        &self,
        _img: &image::DynamicImage,
        format: &str,
    ) -> Result<String, String> {
        // This is a placeholder - in production, this would call the actual OCR engine
        let result = match format {
            "latex" => r"\int_0^1 x^2 \, dx = \frac{1}{3}".to_string(),
            "text" => "Sample OCR extracted text".to_string(),
            "mathml" => r#"<math><mrow><mi>x</mi><mo>=</mo><mn>2</mn></mrow></math>"#.to_string(),
            "asciimath" => "int_0^1 x^2 dx = 1/3".to_string(),
            _ => "Unknown format".to_string(),
        };
        Ok(result)
    }
}
/// Run the MCP server
///
/// Serves JSON-RPC 2.0 over STDIO: one message per line on stdin, one
/// response per line on stdout. Diagnostics (when enabled) go to stderr.
pub async fn run(args: McpArgs) -> anyhow::Result<()> {
    let server = McpServer::new(&args);
    let debug = args.debug;
    if debug {
        eprintln!("[MCP] SciPix MCP Server starting...");
        eprintln!("[MCP] Version: {}", env!("CARGO_PKG_VERSION"));
    }
    let stdin = io::stdin();
    let mut stdout = io::stdout();
    for line in stdin.lock().lines() {
        // Unreadable input lines are logged (in debug mode) and skipped.
        let line = match line {
            Err(e) => {
                if debug {
                    eprintln!("[MCP ERROR] Failed to read stdin: {}", e);
                }
                continue;
            }
            Ok(l) => l,
        };
        if line.trim().is_empty() {
            continue;
        }
        if debug {
            eprintln!("[MCP DEBUG] Received: {}", line);
        }
        // Malformed JSON gets a -32700 parse error and the loop keeps serving.
        let request = match serde_json::from_str::<JsonRpcRequest>(&line) {
            Ok(req) => req,
            Err(e) => {
                let parse_err =
                    JsonRpcResponse::error(Value::Null, -32700, &format!("Parse error: {}", e));
                let payload = serde_json::to_string(&parse_err).unwrap_or_default();
                writeln!(stdout, "{}", payload)?;
                stdout.flush()?;
                continue;
            }
        };
        let response = server.handle_request(request).await;
        let payload = serde_json::to_string(&response)?;
        if debug {
            eprintln!("[MCP DEBUG] Response: {}", payload);
        }
        // Flush after every message so the host sees responses promptly.
        writeln!(stdout, "{}", payload)?;
        stdout.flush()?;
    }
    Ok(())
}

View File

@@ -0,0 +1,99 @@
pub mod batch;
pub mod config;
pub mod doctor;
pub mod mcp;
pub mod ocr;
pub mod serve;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Common result structure for OCR operations
///
/// Derives `Serialize`/`Deserialize`, so the field names below form the
/// public JSON/TOML output schema — renaming a field is a breaking change.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Source file path
    pub file: PathBuf,
    /// Extracted text content
    pub text: String,
    /// LaTeX representation (if available)
    pub latex: Option<String>,
    /// Confidence score (0.0 to 1.0)
    pub confidence: f64,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
    /// Any errors or warnings
    pub errors: Vec<String>,
}
impl OcrResult {
    /// Construct a result with the required fields. LaTeX starts absent,
    /// processing time at zero, and the error list empty.
    pub fn new(file: PathBuf, text: String, confidence: f64) -> Self {
        Self {
            file,
            text,
            confidence,
            latex: None,
            processing_time_ms: 0,
            errors: Vec::new(),
        }
    }

    /// Builder-style setter: attach a LaTeX representation.
    pub fn with_latex(mut self, latex: String) -> Self {
        self.latex = Some(latex);
        self
    }

    /// Builder-style setter: record the processing time in milliseconds.
    pub fn with_processing_time(mut self, time_ms: u64) -> Self {
        self.processing_time_ms = time_ms;
        self
    }

    /// Append an error or warning message to this result.
    pub fn add_error(&mut self, error: String) {
        self.errors.push(error);
    }
}
/// Configuration for OCR processing
///
/// Loaded from a TOML file when one is supplied on the command line, else
/// built via `Default` (see the `Default` impl below for concrete values).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
    /// Minimum confidence threshold
    pub min_confidence: f64,
    /// Maximum image size in bytes
    pub max_image_size: usize,
    /// Supported file extensions (lowercase, without the leading dot)
    pub supported_extensions: Vec<String>,
    /// API endpoint (if using remote service)
    pub api_endpoint: Option<String>,
    /// API key (if using remote service)
    pub api_key: Option<String>,
}
impl Default for OcrConfig {
    /// Defaults: 0.7 confidence floor, 10 MB size cap, common raster
    /// formats plus PDF, and no remote API configured.
    fn default() -> Self {
        let supported_extensions = ["png", "jpg", "jpeg", "pdf", "gif"]
            .iter()
            .map(|ext| ext.to_string())
            .collect();
        Self {
            min_confidence: 0.7,
            max_image_size: 10 * 1024 * 1024, // 10MB
            supported_extensions,
            api_endpoint: None,
            api_key: None,
        }
    }
}

View File

@@ -0,0 +1,210 @@
use anyhow::{Context, Result};
use clap::Args;
use std::path::PathBuf;
use std::time::Instant;
use tracing::{debug, info};
use super::{OcrConfig, OcrResult};
use crate::cli::{output, Cli, OutputFormat};
/// Process a single image or file with OCR
///
/// Argument struct for the `ocr` subcommand; clap attributes below define
/// the CLI surface, so flag names/defaults are part of the public interface.
#[derive(Args, Debug, Clone)]
pub struct OcrArgs {
    /// Path to the image file to process
    #[arg(value_name = "FILE", help = "Path to the image file")]
    pub file: PathBuf,
    /// Minimum confidence threshold (0.0 to 1.0); results below it are
    /// rejected unless `--force` is given
    #[arg(
        short = 't',
        long,
        default_value = "0.7",
        help = "Minimum confidence threshold for results"
    )]
    pub threshold: f64,
    /// Save output to file instead of stdout
    #[arg(
        short,
        long,
        value_name = "OUTPUT",
        help = "Save output to file instead of stdout"
    )]
    pub output: Option<PathBuf>,
    /// Pretty-print JSON output
    #[arg(
        short,
        long,
        help = "Pretty-print JSON output (only with --format json)"
    )]
    pub pretty: bool,
    /// Include metadata in output
    #[arg(short, long, help = "Include processing metadata in output")]
    pub metadata: bool,
    /// Force processing even if confidence is below threshold
    #[arg(
        short = 'f',
        long,
        help = "Force processing even if confidence is below threshold"
    )]
    pub force: bool,
}
/// Run the `ocr` subcommand: validate the input file, process it, apply the
/// confidence threshold, and emit the formatted result to stdout or a file.
pub async fn execute(args: OcrArgs, cli: &Cli) -> Result<()> {
    info!("Processing file: {}", args.file.display());
    // The input must exist and be a regular file.
    anyhow::ensure!(args.file.exists(), "File not found: {}", args.file.display());
    anyhow::ensure!(args.file.is_file(), "Not a file: {}", args.file.display());
    // Load configuration
    let config = load_config(cli.config.as_ref())?;
    // Reject files whose extension is missing or not in the supported list.
    match args.file.extension() {
        Some(ext) => {
            let ext = ext.to_string_lossy().to_lowercase();
            anyhow::ensure!(
                config.supported_extensions.contains(&ext),
                "Unsupported file extension: {}. Supported: {}",
                ext,
                config.supported_extensions.join(", ")
            );
        }
        None => anyhow::bail!("File has no extension"),
    }
    // Enforce the configured size cap before reading the file.
    let metadata = std::fs::metadata(&args.file).context("Failed to read file metadata")?;
    anyhow::ensure!(
        metadata.len() as usize <= config.max_image_size,
        "File too large: {} bytes (max: {} bytes)",
        metadata.len(),
        config.max_image_size
    );
    // Process the file, timing the run for the debug log.
    let start = Instant::now();
    let result = process_file(&args.file, &config).await?;
    debug!("Processing completed in {:?}", start.elapsed());
    // Low-confidence results are rejected unless --force was given.
    anyhow::ensure!(
        result.confidence >= args.threshold || args.force,
        "Confidence {} is below threshold {} (use --force to override)",
        result.confidence,
        args.threshold
    );
    // Render and deliver the result.
    let rendered = format_result(&result, &cli.format, args.pretty, args.metadata)?;
    match &args.output {
        Some(output_path) => {
            std::fs::write(output_path, &rendered).context("Failed to write output file")?;
            info!("Output saved to: {}", output_path.display());
        }
        None => println!("{}", rendered),
    }
    // Display summary if not quiet
    if !cli.quiet {
        output::print_ocr_summary(&result);
    }
    Ok(())
}
/// Placeholder OCR pass over a single file.
///
/// TODO: wire in the real OCR engine. For now this sleeps 100ms to simulate
/// latency and returns a canned result with fixed text/LaTeX and 0.95
/// confidence; the config is unused.
async fn process_file(file: &PathBuf, _config: &OcrConfig) -> Result<OcrResult> {
    let start = Instant::now();
    // Simulated OCR latency.
    tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
    Ok(OcrResult {
        file: file.clone(),
        text: "Sample OCR text from image".to_string(),
        latex: Some(r"\int_0^1 x^2 \, dx = \frac{1}{3}".to_string()),
        confidence: 0.95,
        processing_time_ms: start.elapsed().as_millis() as u64,
        errors: Vec::new(),
    })
}
/// Render an OCR result in the requested output format.
///
/// `pretty` only affects JSON; `include_metadata` switches JSON between the
/// full struct and a trimmed object, and appends a footer in Markdown.
fn format_result(
    result: &OcrResult,
    format: &OutputFormat,
    pretty: bool,
    include_metadata: bool,
) -> Result<String> {
    match format {
        OutputFormat::Json => {
            let serialized = if include_metadata {
                // Full struct, including file path / timing / errors.
                if pretty {
                    serde_json::to_string_pretty(result)
                } else {
                    serde_json::to_string(result)
                }
            } else {
                // Trimmed object: just the recognition payload.
                let simple = serde_json::json!({
                    "text": result.text,
                    "latex": result.latex,
                    "confidence": result.confidence,
                });
                if pretty {
                    serde_json::to_string_pretty(&simple)
                } else {
                    serde_json::to_string(&simple)
                }
            };
            serialized.context("Failed to serialize to JSON")
        }
        OutputFormat::Text => Ok(result.text.clone()),
        // Fall back to the plain text when no LaTeX is available.
        OutputFormat::Latex => Ok(result.latex.clone().unwrap_or_else(|| result.text.clone())),
        OutputFormat::Markdown => {
            let mut md = format!("# OCR Result\n\n{}\n", result.text);
            if let Some(latex) = &result.latex {
                md.push_str(&format!("\n## LaTeX\n\n```latex\n{}\n```\n", latex));
            }
            if include_metadata {
                md.push_str(&format!(
                    "\n---\n\nConfidence: {:.2}%\nProcessing time: {}ms\n",
                    result.confidence * 100.0,
                    result.processing_time_ms
                ));
            }
            Ok(md)
        }
        OutputFormat::MathMl => {
            // TODO: Implement MathML conversion — the text is wrapped
            // verbatim (not XML-escaped) for now.
            Ok(format!(
                "<math xmlns=\"http://www.w3.org/1998/Math/MathML\">\n  {}\n</math>",
                result.text
            ))
        }
    }
}
/// Load the OCR configuration: parse the TOML file when a path was given,
/// otherwise fall back to the built-in defaults.
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
    match config_path {
        Some(path) => {
            let raw = std::fs::read_to_string(path).context("Failed to read config file")?;
            toml::from_str(&raw).context("Failed to parse config file")
        }
        None => Ok(OcrConfig::default()),
    }
}

View File

@@ -0,0 +1,293 @@
use anyhow::{Context, Result};
use axum::{
extract::{Multipart, State},
http::StatusCode,
response::IntoResponse,
routing::{get, post},
Json, Router,
};
use clap::Args;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::signal;
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing::{info, warn};
use super::{OcrConfig, OcrResult};
use crate::cli::Cli;
/// Start the API server
///
/// Argument struct for the `serve` subcommand. Host/port may also come from
/// the MATHPIX_HOST / MATHPIX_PORT environment variables.
#[derive(Args, Debug, Clone)]
pub struct ServeArgs {
    /// Port to listen on
    #[arg(
        short,
        long,
        default_value = "8080",
        env = "MATHPIX_PORT",
        help = "Port to listen on"
    )]
    pub port: u16,
    /// Host to bind to
    #[arg(
        short = 'H',
        long,
        default_value = "127.0.0.1",
        env = "MATHPIX_HOST",
        help = "Host address to bind to"
    )]
    pub host: String,
    /// Directory containing ML models
    #[arg(
        long,
        value_name = "DIR",
        help = "Directory containing ML models to preload"
    )]
    pub model_dir: Option<PathBuf>,
    /// Enable CORS
    #[arg(long, help = "Enable CORS for cross-origin requests")]
    pub cors: bool,
    /// Maximum request size in MB (converted to bytes in `execute`)
    #[arg(long, default_value = "10", help = "Maximum request size in megabytes")]
    pub max_size: usize,
    /// Number of worker threads
    #[arg(
        short = 'w',
        long,
        default_value = "4",
        help = "Number of worker threads"
    )]
    pub workers: usize,
}
// Shared, cheaply clonable state handed to every request handler.
#[derive(Clone)]
struct AppState {
    // OCR configuration shared across handlers.
    config: Arc<OcrConfig>,
    // Maximum accepted upload size in bytes (derived from --max-size MB).
    max_size: usize,
}
/// Run the `serve` subcommand: build the axum router, bind the listener, and
/// serve requests until a shutdown signal (Ctrl+C / SIGTERM) arrives.
pub async fn execute(args: ServeArgs, cli: &Cli) -> Result<()> {
    info!("Starting Scipix API server");
    let config = Arc::new(load_config(cli.config.as_ref())?);
    // Optionally preload models before accepting traffic.
    if let Some(model_dir) = &args.model_dir {
        info!("Preloading models from: {}", model_dir.display());
        preload_models(model_dir)?;
    }
    // Convert the CLI megabyte limit to bytes once, up front.
    let state = AppState {
        config,
        max_size: args.max_size * 1024 * 1024,
    };
    let mut app = Router::new()
        .route("/", get(root))
        .route("/health", get(health))
        .route("/api/v1/ocr", post(ocr_handler))
        .route("/api/v1/batch", post(batch_handler))
        .with_state(state)
        .layer(TraceLayer::new_for_http());
    if args.cors {
        // Permissive CORS: any origin, method, and header.
        app = app.layer(CorsLayer::permissive());
        info!("CORS enabled");
    }
    let addr: SocketAddr = format!("{}:{}", args.host, args.port)
        .parse()
        .context("Invalid host/port combination")?;
    info!("Server listening on http://{}", addr);
    info!("API endpoints:");
    info!("  POST http://{}/api/v1/ocr - Single file OCR", addr);
    info!("  POST http://{}/api/v1/batch - Batch processing", addr);
    info!("  GET  http://{}/health - Health check", addr);
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .context("Failed to bind to address")?;
    // Serve until the shutdown future resolves, then drain gracefully.
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .context("Server error")?;
    info!("Server shutdown complete");
    Ok(())
}
// GET / — plain-text index listing the available endpoints.
async fn root() -> &'static str {
    "Scipix OCR API Server\n\nEndpoints:\n  POST /api/v1/ocr - Single file OCR\n  POST /api/v1/batch - Batch processing\n  GET  /health - Health check"
}
/// GET /health — JSON liveness probe reporting status and crate version.
async fn health() -> impl IntoResponse {
    let body = serde_json::json!({
        "status": "healthy",
        "version": env!("CARGO_PKG_VERSION"),
    });
    Json(body)
}
/// POST /api/v1/ocr — process a single uploaded image.
///
/// Scans the multipart body for the first part named "file", enforces the
/// configured size cap (413 on violation), and returns the OCR result as
/// JSON. 400 if no "file" part is present or the body is malformed.
async fn ocr_handler(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> Result<Json<OcrResult>, (StatusCode, String)> {
    while let Some(field) = multipart
        .next_field()
        .await
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
    {
        // Only the "file" part is processed; anything else is skipped.
        if field.name().unwrap_or("") != "file" {
            continue;
        }
        let data = field
            .bytes()
            .await
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
        if data.len() > state.max_size {
            return Err((
                StatusCode::PAYLOAD_TOO_LARGE,
                format!(
                    "File too large: {} bytes (max: {} bytes)",
                    data.len(),
                    state.max_size
                ),
            ));
        }
        // Process the file
        let result = process_image_data(&data, &state.config)
            .await
            .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;
        return Ok(Json(result));
    }
    Err((StatusCode::BAD_REQUEST, "No file provided".to_string()))
}
/// POST /api/v1/batch — process multiple uploaded images.
///
/// Reads every multipart part named "files"; oversized or failing files are
/// logged and skipped rather than aborting the batch. Returns 400 only when
/// no file could be processed at all.
async fn batch_handler(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> Result<Json<Vec<OcrResult>>, (StatusCode, String)> {
    let mut results = Vec::new();
    while let Some(field) = multipart
        .next_field()
        .await
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
    {
        // Only parts named "files" are batch inputs.
        if field.name().unwrap_or("") != "files" {
            continue;
        }
        let data = field
            .bytes()
            .await
            .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
        if data.len() > state.max_size {
            warn!("Skipping file: too large ({} bytes)", data.len());
            continue;
        }
        // Best-effort: a failed file is logged, not fatal to the batch.
        match process_image_data(&data, &state.config).await {
            Ok(result) => results.push(result),
            Err(e) => warn!("Failed to process file: {}", e),
        }
    }
    if results.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            "No valid files processed".to_string(),
        ));
    }
    Ok(Json(results))
}
/// Placeholder OCR over raw uploaded bytes.
///
/// TODO: replace with the real OCR engine. Currently sleeps 50ms to simulate
/// work and returns a canned result that echoes the payload size; the config
/// is unused.
async fn process_image_data(data: &[u8], _config: &OcrConfig) -> Result<OcrResult> {
    tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
    let text = format!("OCR text from uploaded image ({} bytes)", data.len());
    Ok(OcrResult {
        file: PathBuf::from("uploaded_file"),
        text,
        latex: Some(r"\text{Sample LaTeX}".to_string()),
        confidence: 0.92,
        processing_time_ms: 50,
        errors: Vec::new(),
    })
}
/// Validate the model directory before the server starts.
///
/// Placeholder: only checks that the path exists and is a directory — no
/// models are actually loaded yet.
fn preload_models(model_dir: &PathBuf) -> Result<()> {
    anyhow::ensure!(
        model_dir.exists(),
        "Model directory not found: {}",
        model_dir.display()
    );
    anyhow::ensure!(model_dir.is_dir(), "Not a directory: {}", model_dir.display());
    // TODO: Implement model preloading
    info!("Models preloaded from {}", model_dir.display());
    Ok(())
}
/// Parse the TOML config at `config_path`, or fall back to the built-in
/// defaults when no path was provided.
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
    match config_path {
        None => Ok(OcrConfig::default()),
        Some(path) => {
            let raw = std::fs::read_to_string(path).context("Failed to read config file")?;
            toml::from_str(&raw).context("Failed to parse config file")
        }
    }
}
/// Future that resolves on the first shutdown signal.
///
/// On Unix this is either Ctrl+C (SIGINT) or SIGTERM; on other platforms
/// only Ctrl+C, with the SIGTERM branch replaced by a pending future so the
/// `select!` shape stays the same.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };
    // Non-Unix targets have no SIGTERM; never resolves.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();
    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C signal");
        },
        _ = terminate => {
            info!("Received terminate signal");
        },
    }
}

View File

@@ -0,0 +1,115 @@
pub mod commands;
pub mod output;
use clap::{Parser, Subcommand};
use std::path::PathBuf;
/// Scipix CLI - OCR and mathematical content processing
///
/// Top-level clap parser. All flags below are `global`, so they may appear
/// before or after any subcommand.
#[derive(Parser, Debug)]
#[command(
    name = "scipix-cli",
    version,
    about = "A Rust-based CLI for Scipix OCR processing",
    long_about = "Process images with OCR, extract mathematical formulas, and convert to LaTeX or other formats.\n\n\
                  Supports single file processing, batch operations, and API server mode."
)]
pub struct Cli {
    /// Path to configuration file (also read from MATHPIX_CONFIG)
    #[arg(
        short,
        long,
        global = true,
        env = "MATHPIX_CONFIG",
        help = "Path to configuration file"
    )]
    pub config: Option<PathBuf>,
    /// Enable verbose logging
    #[arg(
        short,
        long,
        global = true,
        help = "Enable verbose logging (DEBUG level)"
    )]
    pub verbose: bool,
    /// Suppress all non-error output (mutually exclusive with --verbose)
    #[arg(
        short,
        long,
        global = true,
        conflicts_with = "verbose",
        help = "Suppress all non-error output"
    )]
    pub quiet: bool,
    /// Output format (json, text, latex, markdown)
    #[arg(
        short,
        long,
        global = true,
        default_value = "text",
        help = "Output format for results"
    )]
    pub format: OutputFormat,
    #[command(subcommand)]
    pub command: Commands,
}
/// Top-level subcommands; each variant carries its own clap `Args` struct.
#[derive(Subcommand, Debug)]
pub enum Commands {
    /// Process a single image or file with OCR
    Ocr(commands::ocr::OcrArgs),
    /// Process multiple files in batch mode
    Batch(commands::batch::BatchArgs),
    /// Start the API server
    Serve(commands::serve::ServeArgs),
    /// Start the MCP (Model Context Protocol) server for AI integration
    Mcp(commands::mcp::McpArgs),
    /// Manage configuration
    Config(commands::config::ConfigArgs),
    /// Diagnose environment and optimize configuration
    Doctor(commands::doctor::DoctorArgs),
    /// Show version information
    Version,
    /// Generate shell completions
    Completions {
        /// Shell to generate completions for (bash, zsh, fish, powershell);
        /// optional so the handler can detect or prompt for one
        #[arg(value_enum)]
        shell: Option<clap_complete::Shell>,
    },
}
/// Output formats accepted by the global `--format` flag; `Display` below
/// yields the lowercase CLI names.
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum OutputFormat {
    /// Plain text output
    Text,
    /// JSON output
    Json,
    /// LaTeX format
    Latex,
    /// Markdown format
    Markdown,
    /// MathML format
    MathMl,
}
impl std::fmt::Display for OutputFormat {
    /// Render the lowercase name used on the command line.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            OutputFormat::Text => "text",
            OutputFormat::Json => "json",
            OutputFormat::Latex => "latex",
            OutputFormat::Markdown => "markdown",
            OutputFormat::MathMl => "mathml",
        };
        write!(f, "{}", name)
    }
}

View File

@@ -0,0 +1,223 @@
use comfy_table::{modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL, Cell, Color, Table};
use console::style;
use super::commands::OcrResult;
/// Print a summary of a single OCR result
///
/// Renders a property/value table (file, confidence, timing, truncated
/// LaTeX) and lists any recorded errors below it.
pub fn print_ocr_summary(result: &OcrResult) {
    println!("\n{}", style("OCR Processing Summary").bold().cyan());
    // 60-char divider. The original repeated an empty string, which renders
    // nothing — presumably a box-drawing char lost in transit.
    println!("{}", style("─".repeat(60)).dim());
    let mut table = Table::new();
    table
        .load_preset(UTF8_FULL)
        .apply_modifier(UTF8_ROUND_CORNERS)
        .set_header(vec![
            Cell::new("Property").fg(Color::Cyan),
            Cell::new("Value").fg(Color::Green),
        ]);
    table.add_row(vec![
        Cell::new("File"),
        Cell::new(result.file.display().to_string()),
    ]);
    table.add_row(vec![
        Cell::new("Confidence"),
        Cell::new(format!("{:.2}%", result.confidence * 100.0))
            .fg(confidence_color(result.confidence)),
    ]);
    table.add_row(vec![
        Cell::new("Processing Time"),
        Cell::new(format!("{}ms", result.processing_time_ms)),
    ]);
    if let Some(latex) = &result.latex {
        // Truncate on a character boundary: the previous byte slice
        // (&latex[..50]) panicked when byte 50 fell inside a multi-byte
        // UTF-8 sequence.
        let display_latex = if latex.chars().count() > 50 {
            let head: String = latex.chars().take(50).collect();
            format!("{}...", head)
        } else {
            latex.clone()
        };
        table.add_row(vec![Cell::new("LaTeX"), Cell::new(display_latex)]);
    }
    if !result.errors.is_empty() {
        table.add_row(vec![
            Cell::new("Errors").fg(Color::Red),
            Cell::new(result.errors.len().to_string()).fg(Color::Red),
        ]);
    }
    println!("{table}");
    if !result.errors.is_empty() {
        println!("\n{}", style("Errors:").bold().red());
        for (i, error) in result.errors.iter().enumerate() {
            println!("  {}. {}", i + 1, style(error).red());
        }
    }
    println!();
}
/// Print a summary of batch processing results
pub fn print_batch_summary(passed: &[OcrResult], failed: &[OcrResult], threshold: f64) {
println!("\n{}", style("Batch Processing Summary").bold().cyan());
println!("{}", style("".repeat(60)).dim());
let total = passed.len() + failed.len();
let avg_confidence = if !passed.is_empty() {
passed.iter().map(|r| r.confidence).sum::<f64>() / passed.len() as f64
} else {
0.0
};
let total_time: u64 = passed.iter().map(|r| r.processing_time_ms).sum();
let avg_time = if !passed.is_empty() {
total_time / passed.len() as u64
} else {
0
};
let mut table = Table::new();
table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
Cell::new("Metric").fg(Color::Cyan),
Cell::new("Value").fg(Color::Green),
]);
table.add_row(vec![Cell::new("Total Files"), Cell::new(total.to_string())]);
table.add_row(vec![
Cell::new("Passed").fg(Color::Green),
Cell::new(format!(
"{} ({:.1}%)",
passed.len(),
(passed.len() as f64 / total as f64) * 100.0
))
.fg(Color::Green),
]);
table.add_row(vec![
Cell::new("Failed").fg(Color::Red),
Cell::new(format!(
"{} ({:.1}%)",
failed.len(),
(failed.len() as f64 / total as f64) * 100.0
))
.fg(if failed.is_empty() {
Color::Green
} else {
Color::Red
}),
]);
table.add_row(vec![
Cell::new("Threshold"),
Cell::new(format!("{:.2}%", threshold * 100.0)),
]);
table.add_row(vec![
Cell::new("Avg Confidence"),
Cell::new(format!("{:.2}%", avg_confidence * 100.0)).fg(confidence_color(avg_confidence)),
]);
table.add_row(vec![
Cell::new("Avg Processing Time"),
Cell::new(format!("{}ms", avg_time)),
]);
table.add_row(vec![
Cell::new("Total Processing Time"),
Cell::new(format!("{:.2}s", total_time as f64 / 1000.0)),
]);
println!("{table}");
if !failed.is_empty() {
println!("\n{}", style("Failed Files:").bold().red());
let mut failed_table = Table::new();
failed_table
.load_preset(UTF8_FULL)
.apply_modifier(UTF8_ROUND_CORNERS)
.set_header(vec![
Cell::new("#").fg(Color::Cyan),
Cell::new("File").fg(Color::Cyan),
Cell::new("Confidence").fg(Color::Cyan),
]);
for (i, result) in failed.iter().enumerate() {
failed_table.add_row(vec![
Cell::new((i + 1).to_string()),
Cell::new(result.file.display().to_string()),
Cell::new(format!("{:.2}%", result.confidence * 100.0)).fg(Color::Red),
]);
}
println!("{failed_table}");
}
// Summary statistics
println!("\n{}", style("Statistics:").bold().cyan());
if !passed.is_empty() {
let confidences: Vec<f64> = passed.iter().map(|r| r.confidence).collect();
let min_confidence = confidences.iter().cloned().fold(f64::INFINITY, f64::min);
let max_confidence = confidences
.iter()
.cloned()
.fold(f64::NEG_INFINITY, f64::max);
println!(
" Min confidence: {}",
style(format!("{:.2}%", min_confidence * 100.0)).green()
);
println!(
" Max confidence: {}",
style(format!("{:.2}%", max_confidence * 100.0)).green()
);
let times: Vec<u64> = passed.iter().map(|r| r.processing_time_ms).collect();
let min_time = times.iter().min().unwrap_or(&0);
let max_time = times.iter().max().unwrap_or(&0);
println!(" Min processing time: {}ms", style(min_time).cyan());
println!(" Max processing time: {}ms", style(max_time).cyan());
}
println!();
}
/// Map a confidence value to a display color:
/// green for >= 0.9, yellow for >= 0.7, red otherwise.
fn confidence_color(confidence: f64) -> Color {
    match confidence {
        c if c >= 0.9 => Color::Green,
        c if c >= 0.7 => Color::Yellow,
        _ => Color::Red,
    }
}
/// Create a progress bar style for batch processing
///
/// unwrap: the template is a fixed literal known to be valid, so
/// `ProgressStyle::template` cannot fail here.
pub fn create_progress_style() -> indicatif::ProgressStyle {
    indicatif::ProgressStyle::default_bar()
        .template(
            "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}",
        )
        .unwrap()
        .progress_chars("█▓▒░ ")
}
/// Create a spinner style for individual file processing
///
/// unwrap: the template is a fixed, valid literal (cannot fail).
pub fn create_spinner_style() -> indicatif::ProgressStyle {
    indicatif::ProgressStyle::default_spinner()
        .template("{spinner:.cyan} {msg}")
        .unwrap()
        .tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
}

View File

@@ -0,0 +1,455 @@
//! Configuration system for Ruvector-Scipix
//!
//! Comprehensive configuration with TOML support, environment overrides, and validation.
use crate::error::{Result, ScipixError};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Main configuration structure
///
/// Loadable from TOML (`from_file`), from `MATHPIX_`-prefixed environment
/// variables (`from_env`), or via the presets (`high_accuracy`,
/// `high_speed`, `minimal`). Always validated via `validate`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// OCR processing configuration
    pub ocr: OcrConfig,
    /// Model configuration
    pub model: ModelConfig,
    /// Preprocessing configuration
    pub preprocess: PreprocessConfig,
    /// Output format configuration
    pub output: OutputConfig,
    /// Performance tuning
    pub performance: PerformanceConfig,
    /// Cache configuration
    pub cache: CacheConfig,
}
/// OCR engine configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
    /// Confidence threshold (0.0-1.0); enforced by `Config::validate`
    pub confidence_threshold: f32,
    /// Maximum processing time in seconds
    pub timeout: u64,
    /// Enable GPU acceleration
    pub use_gpu: bool,
    /// Language codes (e.g., ["en", "es"])
    pub languages: Vec<String>,
    /// Enable equation detection
    pub detect_equations: bool,
    /// Enable table detection
    pub detect_tables: bool,
}
/// Model configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Path to OCR model
    pub model_path: String,
    /// Model version
    pub version: String,
    /// Batch size for processing (must be > 0; enforced by `Config::validate`)
    pub batch_size: usize,
    /// Model precision — one of "fp16", "fp32", "int8" (validated)
    pub precision: String,
    /// Enable quantization
    pub quantize: bool,
}
/// Image preprocessing configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessConfig {
    /// Enable auto-rotation
    pub auto_rotate: bool,
    /// Enable denoising
    pub denoise: bool,
    /// Enable contrast enhancement
    pub enhance_contrast: bool,
    /// Enable binarization
    pub binarize: bool,
    /// Target DPI for scaling
    pub target_dpi: u32,
    /// Maximum image dimension (pixels)
    pub max_dimension: u32,
}
/// Output format configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OutputConfig {
    /// Output formats — each must be "latex", "mathml", or "asciimath"
    /// (enforced by `Config::validate`)
    pub formats: Vec<String>,
    /// Include confidence scores
    pub include_confidence: bool,
    /// Include bounding boxes
    pub include_bbox: bool,
    /// Pretty print JSON
    pub pretty_print: bool,
    /// Include metadata
    pub include_metadata: bool,
}
/// Performance tuning configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceConfig {
    /// Number of worker threads (defaults to the CPU count)
    pub num_threads: usize,
    /// Enable parallel processing
    pub parallel: bool,
    /// Memory limit in MB
    pub memory_limit: usize,
    /// Enable profiling
    pub profile: bool,
}
/// Cache configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Enable caching
    pub enabled: bool,
    /// Cache capacity (number of entries)
    pub capacity: usize,
    /// Similarity threshold for cache hits (0.0-1.0; validated)
    pub similarity_threshold: f32,
    /// Cache TTL in seconds
    pub ttl: u64,
    /// Vector dimension for embeddings
    pub vector_dimension: usize,
    /// Enable persistent cache
    pub persistent: bool,
    /// Cache directory path
    pub cache_dir: String,
}
impl Default for Config {
    /// Balanced defaults: CPU-only OCR in English, fp32 un-quantized model,
    /// standard preprocessing, LaTeX output, and in-memory caching enabled.
    fn default() -> Self {
        let ocr = OcrConfig {
            confidence_threshold: 0.7,
            timeout: 30,
            use_gpu: false,
            languages: vec!["en".to_string()],
            detect_equations: true,
            detect_tables: true,
        };
        let model = ModelConfig {
            model_path: "models/scipix-ocr".to_string(),
            version: "1.0.0".to_string(),
            batch_size: 1,
            precision: "fp32".to_string(),
            quantize: false,
        };
        let preprocess = PreprocessConfig {
            auto_rotate: true,
            denoise: true,
            enhance_contrast: true,
            binarize: false,
            target_dpi: 300,
            max_dimension: 4096,
        };
        let output = OutputConfig {
            formats: vec!["latex".to_string()],
            include_confidence: true,
            include_bbox: false,
            pretty_print: true,
            include_metadata: false,
        };
        let performance = PerformanceConfig {
            num_threads: num_cpus::get(),
            parallel: true,
            memory_limit: 2048,
            profile: false,
        };
        let cache = CacheConfig {
            enabled: true,
            capacity: 1000,
            similarity_threshold: 0.95,
            ttl: 3600,
            vector_dimension: 512,
            persistent: false,
            cache_dir: ".cache/scipix".to_string(),
        };
        Self {
            ocr,
            model,
            preprocess,
            output,
            performance,
            cache,
        }
    }
}
impl Config {
    /// Load configuration from TOML file
    ///
    /// The parsed configuration is validated before being returned.
    ///
    /// # Arguments
    ///
    /// * `path` - Path to TOML configuration file
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read, is not valid TOML, or
    /// fails `validate`.
    ///
    /// # Examples
    ///
    /// ```rust,no_run
    /// use ruvector_scipix::Config;
    ///
    /// let config = Config::from_file("scipix.toml").unwrap();
    /// ```
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = std::fs::read_to_string(path)?;
        let config: Config = toml::from_str(&content)?;
        config.validate()?;
        Ok(config)
    }
    /// Save configuration to TOML file
    ///
    /// # Arguments
    ///
    /// * `path` - Path to save TOML configuration
    pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let content = toml::to_string_pretty(self)?;
        std::fs::write(path, content)?;
        Ok(())
    }
    /// Load configuration from environment variables
    ///
    /// Environment variables should be prefixed with `MATHPIX_`
    /// and use double underscores for nested fields. Unset variables leave
    /// the corresponding defaults untouched.
    ///
    /// # Examples
    ///
    /// ```bash
    /// export MATHPIX_OCR__CONFIDENCE_THRESHOLD=0.8
    /// export MATHPIX_MODEL__BATCH_SIZE=4
    /// ```
    pub fn from_env() -> Result<Self> {
        let mut config = Self::default();
        config.apply_env_overrides()?;
        Ok(config)
    }
    /// Apply environment variable overrides
    ///
    /// Only a subset of fields is overridable: OCR threshold/timeout/GPU,
    /// model path/batch size, and cache enabled/capacity.
    fn apply_env_overrides(&mut self) -> Result<()> {
        // OCR overrides
        if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
            self.ocr.confidence_threshold = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
        }
        if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
            self.ocr.timeout = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
        }
        if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
            self.ocr.use_gpu = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
        }
        // Model overrides
        if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
            self.model.model_path = val;
        }
        if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
            self.model.batch_size = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
        }
        // Cache overrides
        if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
            self.cache.enabled = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
        }
        if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
            self.cache.capacity = val
                .parse()
                .map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
        }
        Ok(())
    }
    /// Validate configuration
    ///
    /// Checks threshold ranges, batch size, model precision, and output
    /// formats; returns the first violation found as `ScipixError::Config`.
    pub fn validate(&self) -> Result<()> {
        // Thresholds must be valid probabilities (idiomatic range check
        // instead of the manual `< 0.0 || > 1.0` pair).
        if !(0.0..=1.0).contains(&self.ocr.confidence_threshold) {
            return Err(ScipixError::Config(
                "confidence_threshold must be between 0.0 and 1.0".to_string(),
            ));
        }
        if !(0.0..=1.0).contains(&self.cache.similarity_threshold) {
            return Err(ScipixError::Config(
                "similarity_threshold must be between 0.0 and 1.0".to_string(),
            ));
        }
        // Validate batch size
        if self.model.batch_size == 0 {
            return Err(ScipixError::Config(
                "batch_size must be greater than 0".to_string(),
            ));
        }
        // Validate precision
        let valid_precisions = ["fp16", "fp32", "int8"];
        if !valid_precisions.contains(&self.model.precision.as_str()) {
            return Err(ScipixError::Config(format!(
                "precision must be one of: {:?}",
                valid_precisions
            )));
        }
        // Validate output formats
        let valid_formats = ["latex", "mathml", "asciimath"];
        for format in &self.output.formats {
            if !valid_formats.contains(&format.as_str()) {
                return Err(ScipixError::Config(format!(
                    "Invalid output format: {}. Must be one of: {:?}",
                    format, valid_formats
                )));
            }
        }
        Ok(())
    }
    /// Create high-accuracy preset configuration
    ///
    /// Raises thresholds, keeps full fp32 precision, and enables all
    /// preprocessing at the cost of speed.
    pub fn high_accuracy() -> Self {
        let mut config = Self::default();
        config.ocr.confidence_threshold = 0.9;
        config.model.precision = "fp32".to_string();
        config.model.quantize = false;
        config.preprocess.denoise = true;
        config.preprocess.enhance_contrast = true;
        config.cache.similarity_threshold = 0.98;
        config
    }
    /// Create high-speed preset configuration
    ///
    /// Lowers thresholds and precision (fp16 + quantization), batches work,
    /// and skips optional preprocessing.
    pub fn high_speed() -> Self {
        let mut config = Self::default();
        config.ocr.confidence_threshold = 0.6;
        config.model.precision = "fp16".to_string();
        config.model.quantize = true;
        config.model.batch_size = 4;
        config.preprocess.denoise = false;
        config.preprocess.enhance_contrast = false;
        config.performance.parallel = true;
        config.cache.similarity_threshold = 0.85;
        config
    }
    /// Create minimal configuration
    ///
    /// Disables caching, optional preprocessing, and parallelism.
    pub fn minimal() -> Self {
        let mut config = Self::default();
        config.cache.enabled = false;
        config.preprocess.denoise = false;
        config.preprocess.enhance_contrast = false;
        config.performance.parallel = false;
        config
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Every preset must pass its own validation and carry the values the
    // preset advertises.
    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.7);
        assert!(config.cache.enabled);
    }
    #[test]
    fn test_high_accuracy_config() {
        let config = Config::high_accuracy();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.9);
        assert_eq!(config.cache.similarity_threshold, 0.98);
    }
    #[test]
    fn test_high_speed_config() {
        let config = Config::high_speed();
        assert!(config.validate().is_ok());
        assert_eq!(config.model.precision, "fp16");
        assert!(config.model.quantize);
    }
    #[test]
    fn test_minimal_config() {
        let config = Config::minimal();
        assert!(config.validate().is_ok());
        assert!(!config.cache.enabled);
    }
    // Each invalid field must be rejected by validate().
    #[test]
    fn test_invalid_confidence_threshold() {
        let mut config = Config::default();
        config.ocr.confidence_threshold = 1.5;
        assert!(config.validate().is_err());
    }
    #[test]
    fn test_invalid_batch_size() {
        let mut config = Config::default();
        config.model.batch_size = 0;
        assert!(config.validate().is_err());
    }
    #[test]
    fn test_invalid_precision() {
        let mut config = Config::default();
        config.model.precision = "invalid".to_string();
        assert!(config.validate().is_err());
    }
    #[test]
    fn test_invalid_output_format() {
        let mut config = Config::default();
        config.output.formats = vec!["invalid".to_string()];
        assert!(config.validate().is_err());
    }
    // Round-trip through TOML must preserve values.
    #[test]
    fn test_toml_serialization() {
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(
            config.ocr.confidence_threshold,
            deserialized.ocr.confidence_threshold
        );
    }
}

View File

@@ -0,0 +1,228 @@
//! Error types for Ruvector-Scipix
//!
//! Comprehensive error handling with context, HTTP status mapping, and retry logic.
use std::io;
use thiserror::Error;
/// Result type alias for Scipix operations
pub type Result<T> = std::result::Result<T, ScipixError>;
/// Comprehensive error types for all Scipix operations
///
/// Each variant's retry semantics are defined in `is_retryable` below; the
/// `#[error]` strings are the `Display` output via `thiserror`.
#[derive(Debug, Error)]
pub enum ScipixError {
    /// Image loading or processing error
    #[error("Image error: {0}")]
    Image(String),
    /// Machine learning model error
    #[error("Model error: {0}")]
    Model(String),
    /// OCR processing error
    #[error("OCR error: {0}")]
    Ocr(String),
    /// LaTeX generation or parsing error
    #[error("LaTeX error: {0}")]
    LaTeX(String),
    /// Configuration error
    #[error("Configuration error: {0}")]
    Config(String),
    /// I/O error (auto-converted from `std::io::Error` via `#[from]`)
    #[error("I/O error: {0}")]
    Io(#[from] io::Error),
    /// Serialization/deserialization error
    #[error("Serialization error: {0}")]
    Serialization(String),
    /// Invalid input error
    #[error("Invalid input: {0}")]
    InvalidInput(String),
    /// Operation timeout (payload is the limit in seconds)
    #[error("Timeout: operation took longer than {0}s")]
    Timeout(u64),
    /// Resource not found
    #[error("Not found: {0}")]
    NotFound(String),
    /// Authentication error
    #[error("Authentication error: {0}")]
    Auth(String),
    /// Rate limit exceeded
    #[error("Rate limit exceeded: {0}")]
    RateLimit(String),
    /// Internal error
    #[error("Internal error: {0}")]
    Internal(String),
}
impl ScipixError {
    /// Check if the error is retryable
    ///
    /// # Returns
    ///
    /// `true` if the operation should be retried, `false` otherwise
    ///
    /// # Examples
    ///
    /// ```rust
    /// use ruvector_scipix::ScipixError;
    ///
    /// let timeout_error = ScipixError::Timeout(30);
    /// assert!(timeout_error.is_retryable());
    ///
    /// let config_error = ScipixError::Config("Invalid parameter".to_string());
    /// assert!(!config_error.is_retryable());
    /// ```
    pub fn is_retryable(&self) -> bool {
        match self {
            // Transient conditions: the same request may succeed later.
            ScipixError::Timeout(_)
            | ScipixError::RateLimit(_)
            | ScipixError::Io(_)
            | ScipixError::Internal(_) => true,
            // Deterministic failures: retrying the same input cannot help.
            ScipixError::Image(_)
            | ScipixError::Model(_)
            | ScipixError::Ocr(_)
            | ScipixError::LaTeX(_)
            | ScipixError::Config(_)
            | ScipixError::Serialization(_)
            | ScipixError::InvalidInput(_)
            | ScipixError::NotFound(_)
            | ScipixError::Auth(_) => false,
        }
    }

    /// Map error to HTTP status code
    ///
    /// # Returns
    ///
    /// HTTP status code representing the error type
    ///
    /// # Examples
    ///
    /// ```rust
    /// use ruvector_scipix::ScipixError;
    ///
    /// let auth_error = ScipixError::Auth("Invalid token".to_string());
    /// assert_eq!(auth_error.status_code(), 401);
    ///
    /// let not_found = ScipixError::NotFound("Model not found".to_string());
    /// assert_eq!(not_found.status_code(), 404);
    /// ```
    pub fn status_code(&self) -> u16 {
        // Exhaustive on purpose (no `_` arm): adding a new variant forces an
        // explicit decision about its HTTP mapping.
        match self {
            ScipixError::InvalidInput(_) | ScipixError::Config(_) => 400,
            ScipixError::Auth(_) => 401,
            ScipixError::NotFound(_) => 404,
            ScipixError::Timeout(_) => 408,
            ScipixError::RateLimit(_) => 429,
            // Processing and infrastructure failures are generic server errors.
            ScipixError::Image(_)
            | ScipixError::Model(_)
            | ScipixError::Ocr(_)
            | ScipixError::LaTeX(_)
            | ScipixError::Io(_)
            | ScipixError::Serialization(_)
            | ScipixError::Internal(_) => 500,
        }
    }

    /// Get error category for logging and metrics
    pub fn category(&self) -> &'static str {
        match self {
            ScipixError::Image(_) => "image",
            ScipixError::Model(_) => "model",
            ScipixError::Ocr(_) => "ocr",
            ScipixError::LaTeX(_) => "latex",
            ScipixError::Config(_) => "config",
            ScipixError::Io(_) => "io",
            ScipixError::Serialization(_) => "serialization",
            ScipixError::InvalidInput(_) => "invalid_input",
            ScipixError::Timeout(_) => "timeout",
            ScipixError::NotFound(_) => "not_found",
            ScipixError::Auth(_) => "auth",
            ScipixError::RateLimit(_) => "rate_limit",
            ScipixError::Internal(_) => "internal",
        }
    }
}
// Conversion from serde_json::Error
impl From<serde_json::Error> for ScipixError {
    /// JSON (de)serialization failures surface as `Serialization` errors.
    fn from(err: serde_json::Error) -> Self {
        Self::Serialization(err.to_string())
    }
}
// Conversion from toml::de::Error
impl From<toml::de::Error> for ScipixError {
    /// TOML parse failures are configuration errors (config files are TOML).
    fn from(err: toml::de::Error) -> Self {
        Self::Config(err.to_string())
    }
}
// Conversion from toml::ser::Error
impl From<toml::ser::Error> for ScipixError {
    /// TOML emission failures surface as `Serialization` errors.
    fn from(err: toml::ser::Error) -> Self {
        Self::Serialization(err.to_string())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = ScipixError::Image("Failed to load".to_string());
        assert_eq!(err.to_string(), "Image error: Failed to load");
    }

    #[test]
    fn test_is_retryable() {
        // Transient failures should be retried.
        assert!(ScipixError::Timeout(30).is_retryable());
        assert!(ScipixError::RateLimit("Exceeded".to_string()).is_retryable());
        assert!(ScipixError::Internal("hiccup".to_string()).is_retryable());
        // Deterministic failures should not.
        assert!(!ScipixError::Config("Invalid".to_string()).is_retryable());
        assert!(!ScipixError::Auth("Unauthorized".to_string()).is_retryable());
        assert!(!ScipixError::InvalidInput("bad".to_string()).is_retryable());
    }

    #[test]
    fn test_status_codes() {
        assert_eq!(ScipixError::Auth("".to_string()).status_code(), 401);
        assert_eq!(ScipixError::NotFound("".to_string()).status_code(), 404);
        assert_eq!(ScipixError::InvalidInput("".to_string()).status_code(), 400);
        assert_eq!(ScipixError::Config("".to_string()).status_code(), 400);
        assert_eq!(ScipixError::RateLimit("".to_string()).status_code(), 429);
        assert_eq!(ScipixError::Timeout(0).status_code(), 408);
        assert_eq!(ScipixError::Internal("".to_string()).status_code(), 500);
        // Processing failures map to a generic server error.
        assert_eq!(ScipixError::Ocr("".to_string()).status_code(), 500);
    }

    #[test]
    fn test_category() {
        assert_eq!(ScipixError::Image("".to_string()).category(), "image");
        assert_eq!(ScipixError::Model("".to_string()).category(), "model");
        assert_eq!(ScipixError::Ocr("".to_string()).category(), "ocr");
        assert_eq!(ScipixError::LaTeX("".to_string()).category(), "latex");
        assert_eq!(ScipixError::Config("".to_string()).category(), "config");
        assert_eq!(ScipixError::Auth("".to_string()).category(), "auth");
        assert_eq!(ScipixError::Timeout(0).category(), "timeout");
    }

    #[test]
    fn test_from_io_error() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "File not found");
        let scipix_err: ScipixError = io_err.into();
        assert!(matches!(scipix_err, ScipixError::Io(_)));
    }

    #[test]
    fn test_from_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid json").unwrap_err();
        let scipix_err: ScipixError = json_err.into();
        assert!(matches!(scipix_err, ScipixError::Serialization(_)));
    }
}

129
examples/scipix/src/lib.rs Normal file
View File

@@ -0,0 +1,129 @@
//! # Ruvector-Scipix
//!
//! A high-performance Rust implementation of Scipix OCR for mathematical expressions and equations.
//! Built on top of ruvector-core for efficient vector-based caching and similarity search.
//!
//! ## Features
//!
//! - **Mathematical OCR**: Extract LaTeX from images of equations
//! - **Vector Caching**: Intelligent caching using image embeddings
//! - **Multiple Formats**: Support for LaTeX, MathML, AsciiMath
//! - **High Performance**: Parallel processing and efficient caching
//! - **Configurable**: Extensive configuration options via TOML or API
//!
//! ## Quick Start
//!
//! ```rust,no_run
//! use ruvector_scipix::{Config, OcrEngine, Result};
//!
//! #[tokio::main]
//! async fn main() -> Result<()> {
//! // Load configuration
//! let config = Config::from_file("scipix.toml")?;
//!
//! // Create OCR engine
//! let engine = OcrEngine::new(config).await?;
//!
//! // Process image
//! let result = engine.process_image("equation.png").await?;
//! println!("LaTeX: {}", result.latex);
//!
//! Ok(())
//! }
//! ```
//!
//! ## Architecture
//!
//! - **config**: Configuration management with TOML support
//! - **error**: Comprehensive error types with context
//! - **math**: LaTeX and mathematical format handling
//! - **ocr**: Core OCR processing engine
//! - **output**: Output formatting and serialization
//! - **preprocess**: Image preprocessing pipeline
//! - **cache**: Vector-based intelligent caching
// Module declarations
pub mod api;
pub mod cli;
pub mod config;
pub mod error;
#[cfg(feature = "cache")]
pub mod cache;
#[cfg(feature = "ocr")]
pub mod ocr;
#[cfg(feature = "math")]
pub mod math;
#[cfg(feature = "preprocess")]
pub mod preprocess;
// Output module is always available
pub mod output;
// Performance optimizations
#[cfg(feature = "optimize")]
pub mod optimize;
// WebAssembly bindings
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
pub mod wasm;
// Public re-exports
pub use api::{state::AppState, ApiServer};
pub use cli::{Cli, Commands};
pub use config::{
CacheConfig, Config, ModelConfig, OcrConfig, OutputConfig, PerformanceConfig, PreprocessConfig,
};
pub use error::{Result, ScipixError};
#[cfg(feature = "cache")]
pub use cache::CacheManager;
/// Library version, taken from `Cargo.toml` at compile time.
pub const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Default configuration preset
///
/// Convenience free-function wrapper around [`Config::default`].
pub fn default_config() -> Config {
    Config::default()
}

/// High-accuracy configuration preset
///
/// Convenience free-function wrapper around [`Config::high_accuracy`].
pub fn high_accuracy_config() -> Config {
    Config::high_accuracy()
}

/// High-speed configuration preset
///
/// Convenience free-function wrapper around [`Config::high_speed`].
pub fn high_speed_config() -> Config {
    Config::high_speed()
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    #[test]
    fn test_default_config() {
        // Every preset wrapper must produce a valid configuration.
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_high_accuracy_config() {
        assert!(high_accuracy_config().validate().is_ok());
    }

    #[test]
    fn test_high_speed_config() {
        assert!(high_speed_config().validate().is_ok());
    }
}

View File

@@ -0,0 +1,465 @@
//! AsciiMath generation from mathematical AST
//!
//! This module converts mathematical AST nodes to AsciiMath notation,
//! a simplified plain-text format for mathematical expressions.
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
/// AsciiMath generator for mathematical expressions
///
/// Converts a [`MathExpr`] AST into AsciiMath text; output style is fixed at
/// construction time via the `unicode` flag.
pub struct AsciiMathGenerator {
    /// Use Unicode symbols (true) or ASCII approximations (false)
    unicode: bool,
}
impl AsciiMathGenerator {
    /// Create a new AsciiMath generator with Unicode support
    pub fn new() -> Self {
        Self { unicode: true }
    }

    /// Create an ASCII-only generator
    pub fn ascii_only() -> Self {
        Self { unicode: false }
    }

    /// Generate AsciiMath string from a mathematical expression
    pub fn generate(&self, expr: &MathExpr) -> String {
        self.generate_node(&expr.root, None)
    }

    /// Generate AsciiMath for a single node.
    ///
    /// `parent_precedence` carries the binding strength of the enclosing
    /// operator; binary sub-expressions that bind looser are parenthesized.
    fn generate_node(&self, node: &MathNode, parent_precedence: Option<u8>) -> String {
        match node {
            MathNode::Symbol { value, .. } => value.clone(),
            MathNode::Number { value, .. } => value.clone(),
            MathNode::Binary { op, left, right } => {
                let precedence = op.precedence();
                // Parenthesize only when this operator binds strictly looser
                // than its parent.
                let needs_parens = parent_precedence.map_or(false, |p| precedence < p);
                let left_str = self.generate_node(left, Some(precedence));
                let right_str = self.generate_node(
                    right,
                    Some(if op.is_left_associative() {
                        precedence
                    } else {
                        precedence + 1
                    }),
                );
                let op_str = self.binary_op_to_asciimath(op);
                let result = format!("{} {} {}", left_str, op_str, right_str);
                if needs_parens {
                    format!("({})", result)
                } else {
                    result
                }
            }
            MathNode::Unary { op, operand } => {
                let op_str = self.unary_op_to_asciimath(op);
                // High-precedence context so `-(x + y)` keeps its parentheses.
                let operand_str = self.generate_node(operand, Some(70));
                format!("{}{}", op_str, operand_str)
            }
            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                let num_str = self.generate_node(numerator, None);
                let den_str = self.generate_node(denominator, None);
                format!("({})/({})", num_str, den_str)
            }
            MathNode::Radical { index, radicand } => {
                let rad_str = self.generate_node(radicand, None);
                if let Some(idx) = index {
                    let idx_str = self.generate_node(idx, None);
                    // AsciiMath n-th root syntax is `root(n)(x)`.
                    // Fixed: a stray space was previously emitted before the
                    // closing parenthesis ("root(n)(x )").
                    format!("root({})({})", idx_str, rad_str)
                } else {
                    format!("sqrt({})", rad_str)
                }
            }
            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                let base_str = self.generate_node(base, Some(65));
                let mut result = base_str;
                if let Some(sub) = subscript {
                    let sub_str = self.generate_node(sub, None);
                    result.push_str(&format!("_{{{}}}", sub_str));
                }
                if let Some(sup) = superscript {
                    let sup_str = self.generate_node(sup, None);
                    result.push_str(&format!("^{{{}}}", sup_str));
                }
                result
            }
            MathNode::Function { name, argument } => {
                let arg_str = self.generate_node(argument, None);
                format!("{}({})", name, arg_str)
            }
            MathNode::Matrix { rows, .. } => {
                // AsciiMath matrices: `,` separates columns, `;` separates rows.
                let mut content = String::new();
                content.push('[');
                for (i, row) in rows.iter().enumerate() {
                    if i > 0 {
                        content.push_str("; ");
                    }
                    for (j, elem) in row.iter().enumerate() {
                        if j > 0 {
                            content.push_str(", ");
                        }
                        content.push_str(&self.generate_node(elem, None));
                    }
                }
                content.push(']');
                content
            }
            MathNode::Group {
                content,
                bracket_type,
            } => {
                let content_str = self.generate_node(content, None);
                let (open, close) = match bracket_type {
                    BracketType::Parentheses => ("(", ")"),
                    BracketType::Brackets => ("[", "]"),
                    BracketType::Braces => ("{", "}"),
                    BracketType::AngleBrackets => {
                        if self.unicode {
                            ("⟨", "⟩")
                        } else {
                            ("<", ">")
                        }
                    }
                    BracketType::Vertical => ("|", "|"),
                    BracketType::DoubleVertical => {
                        if self.unicode {
                            ("‖", "‖")
                        } else {
                            ("||", "||")
                        }
                    }
                    BracketType::Floor => {
                        if self.unicode {
                            ("⌊", "⌋")
                        } else {
                            ("|_", "_|")
                        }
                    }
                    BracketType::Ceiling => {
                        if self.unicode {
                            ("⌈", "⌉")
                        } else {
                            ("|^", "^|")
                        }
                    }
                    BracketType::None => ("", ""),
                };
                format!("{}{}{}", open, content_str, close)
            }
            MathNode::LargeOp {
                op_type,
                lower,
                upper,
                content,
            } => {
                let op_str = self.large_op_to_asciimath(op_type);
                let content_str = self.generate_node(content, None);
                let mut result = op_str.to_string();
                if let Some(low) = lower {
                    let low_str = self.generate_node(low, None);
                    result.push_str(&format!("_{{{}}}", low_str));
                }
                if let Some(up) = upper {
                    let up_str = self.generate_node(up, None);
                    result.push_str(&format!("^{{{}}}", up_str));
                }
                format!("{} {}", result, content_str)
            }
            MathNode::Sequence { elements } => elements
                .iter()
                .map(|e| self.generate_node(e, None))
                .collect::<Vec<_>>()
                .join(", "),
            MathNode::Text { content } => {
                format!("\"{}\"", content)
            }
            MathNode::Empty => String::new(),
        }
    }

    /// Convert binary operator to AsciiMath (Unicode glyph or ASCII digraph).
    fn binary_op_to_asciimath<'a>(&self, op: &'a BinaryOp) -> &'a str {
        if self.unicode {
            match op {
                BinaryOp::Add => "+",
                BinaryOp::Subtract => "-",
                BinaryOp::Multiply => "×",
                BinaryOp::Divide => "÷",
                BinaryOp::Power => "^",
                BinaryOp::Equal => "=",
                BinaryOp::NotEqual => "≠",
                BinaryOp::Less => "<",
                BinaryOp::Greater => ">",
                BinaryOp::LessEqual => "≤",
                BinaryOp::GreaterEqual => "≥",
                BinaryOp::ApproxEqual => "≈",
                BinaryOp::Equivalent => "≡",
                BinaryOp::Similar => "∼",
                BinaryOp::Congruent => "≅",
                BinaryOp::Proportional => "∝",
                BinaryOp::Custom(s) => s,
            }
        } else {
            match op {
                BinaryOp::Add => "+",
                BinaryOp::Subtract => "-",
                BinaryOp::Multiply => "*",
                BinaryOp::Divide => "/",
                BinaryOp::Power => "^",
                BinaryOp::Equal => "=",
                BinaryOp::NotEqual => "!=",
                BinaryOp::Less => "<",
                BinaryOp::Greater => ">",
                BinaryOp::LessEqual => "<=",
                BinaryOp::GreaterEqual => ">=",
                BinaryOp::ApproxEqual => "~~",
                BinaryOp::Equivalent => "-=",
                BinaryOp::Similar => "~",
                BinaryOp::Congruent => "~=",
                BinaryOp::Proportional => "prop",
                BinaryOp::Custom(s) => s.as_str(),
            }
        }
    }

    /// Convert unary operator to AsciiMath
    fn unary_op_to_asciimath<'a>(&self, op: &'a UnaryOp) -> &'a str {
        match op {
            UnaryOp::Plus => "+",
            UnaryOp::Minus => "-",
            UnaryOp::Not => {
                if self.unicode {
                    "¬"
                } else {
                    "not "
                }
            }
            UnaryOp::Custom(s) => s.as_str(),
        }
    }

    /// Convert large operator to AsciiMath (Unicode glyph or keyword).
    fn large_op_to_asciimath(&self, op: &LargeOpType) -> &str {
        if self.unicode {
            match op {
                LargeOpType::Sum => "∑",
                LargeOpType::Product => "∏",
                LargeOpType::Integral => "∫",
                LargeOpType::DoubleIntegral => "∬",
                LargeOpType::TripleIntegral => "∭",
                LargeOpType::ContourIntegral => "∮",
                LargeOpType::Union => "⋃",
                LargeOpType::Intersection => "⋂",
                LargeOpType::Coproduct => "∐",
                LargeOpType::DirectSum => "⊕",
                // No glyph is known for custom operators; fall back to `sum`.
                LargeOpType::Custom(_) => "sum",
            }
        } else {
            match op {
                LargeOpType::Sum => "sum",
                LargeOpType::Product => "prod",
                LargeOpType::Integral => "int",
                LargeOpType::DoubleIntegral => "iint",
                LargeOpType::TripleIntegral => "iiint",
                LargeOpType::ContourIntegral => "oint",
                LargeOpType::Union => "cup",
                LargeOpType::Intersection => "cap",
                LargeOpType::Coproduct => "coprod",
                LargeOpType::DirectSum => "oplus",
                LargeOpType::Custom(_) => "sum",
            }
        }
    }
}
impl Default for AsciiMathGenerator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integer `MathNode::Number` — the fixture every test needs.
    fn num(value: &str) -> MathNode {
        MathNode::Number {
            value: value.to_string(),
            is_decimal: false,
        }
    }

    #[test]
    fn test_simple_number() {
        let expr = MathExpr::new(num("42"), 1.0);
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "42");
    }

    #[test]
    fn test_addition() {
        let expr = MathExpr::new(
            MathNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(num("1")),
                right: Box::new(num("2")),
            },
            1.0,
        );
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "1 + 2");
    }

    #[test]
    fn test_fraction() {
        let expr = MathExpr::new(
            MathNode::Fraction {
                numerator: Box::new(num("1")),
                denominator: Box::new(num("2")),
            },
            1.0,
        );
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "(1)/(2)");
    }

    #[test]
    fn test_sqrt() {
        let expr = MathExpr::new(
            MathNode::Radical {
                index: None,
                radicand: Box::new(num("2")),
            },
            1.0,
        );
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "sqrt(2)");
    }

    #[test]
    fn test_superscript() {
        let expr = MathExpr::new(
            MathNode::Script {
                base: Box::new(MathNode::Symbol {
                    value: "x".to_string(),
                    unicode: Some('x'),
                }),
                subscript: None,
                superscript: Some(Box::new(num("2"))),
            },
            1.0,
        );
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "x^{2}");
    }

    #[test]
    fn test_unicode_vs_ascii() {
        let expr = MathExpr::new(
            MathNode::Binary {
                op: BinaryOp::Multiply,
                left: Box::new(num("2")),
                right: Box::new(num("3")),
            },
            1.0,
        );
        // Same tree, two renderings: Unicode glyph vs ASCII operator.
        assert_eq!(AsciiMathGenerator::new().generate(&expr), "2 × 3");
        assert_eq!(AsciiMathGenerator::ascii_only().generate(&expr), "2 * 3");
    }

    #[test]
    fn test_matrix() {
        let expr = MathExpr::new(
            MathNode::Matrix {
                rows: vec![
                    vec![num("1"), num("2")],
                    vec![num("3"), num("4")],
                ],
                bracket_type: BracketType::Brackets,
            },
            1.0,
        );
        let gen = AsciiMathGenerator::new();
        assert_eq!(gen.generate(&expr), "[1, 2; 3, 4]");
    }
}

View File

@@ -0,0 +1,437 @@
//! Abstract Syntax Tree definitions for mathematical expressions
//!
//! This module defines the complete AST structure for representing mathematical
//! expressions including symbols, operators, fractions, matrices, and more.
use serde::{Deserialize, Serialize};
use std::fmt;
/// A complete mathematical expression with confidence score
///
/// Pairs an AST with the OCR confidence of the recognition that produced it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MathExpr {
    /// Root node of the expression tree
    pub root: MathNode,
    /// Confidence score (0.0 to 1.0) from OCR recognition
    pub confidence: f32,
}
impl MathExpr {
/// Create a new mathematical expression
pub fn new(root: MathNode, confidence: f32) -> Self {
Self { root, confidence }
}
/// Accept a visitor for tree traversal
pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
self.root.accept(visitor);
}
}
/// Main AST node representing any mathematical construct
///
/// Leaf variants: `Symbol`, `Number`, `Text`, `Empty`. All other variants
/// contain child nodes and are traversed by [`MathNode::accept`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MathNode {
    /// A mathematical symbol (variable, Greek letter, operator)
    Symbol {
        value: String,
        unicode: Option<char>,
    },
    /// A numeric value
    Number {
        value: String,
        /// Whether this is part of a decimal number
        is_decimal: bool,
    },
    /// Binary operation (a op b)
    Binary {
        op: BinaryOp,
        left: Box<MathNode>,
        right: Box<MathNode>,
    },
    /// Unary operation (op a)
    Unary { op: UnaryOp, operand: Box<MathNode> },
    /// Fraction (numerator / denominator)
    Fraction {
        numerator: Box<MathNode>,
        denominator: Box<MathNode>,
    },
    /// Radical (√, ∛, etc.)
    Radical {
        /// Index of the radical (2 for square root, 3 for cube root, etc.);
        /// `None` means a plain square root.
        index: Option<Box<MathNode>>,
        radicand: Box<MathNode>,
    },
    /// Subscript or superscript; either or both may be present
    Script {
        base: Box<MathNode>,
        subscript: Option<Box<MathNode>>,
        superscript: Option<Box<MathNode>>,
    },
    /// Function application (sin, cos, log, etc.)
    Function {
        name: String,
        argument: Box<MathNode>,
    },
    /// Matrix or vector (row-major: outer Vec is rows)
    Matrix {
        rows: Vec<Vec<MathNode>>,
        bracket_type: BracketType,
    },
    /// Grouped expression with delimiters
    Group {
        content: Box<MathNode>,
        bracket_type: BracketType,
    },
    /// Large operators (∑, ∫, ∏, etc.) with optional bounds
    LargeOp {
        op_type: LargeOpType,
        lower: Option<Box<MathNode>>,
        upper: Option<Box<MathNode>>,
        content: Box<MathNode>,
    },
    /// Sequence of expressions (e.g., function arguments)
    Sequence { elements: Vec<MathNode> },
    /// Text annotation in math mode
    Text { content: String },
    /// Empty/placeholder node
    Empty,
}
impl MathNode {
    /// Accept a visitor for tree traversal.
    ///
    /// Pre-order: the visitor sees this node first, then each child in
    /// left-to-right / top-to-bottom order.
    pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
        visitor.visit(self);
        match self {
            MathNode::Binary { left, right, .. } => {
                left.accept(visitor);
                right.accept(visitor);
            }
            MathNode::Unary { operand, .. } => {
                operand.accept(visitor);
            }
            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                numerator.accept(visitor);
                denominator.accept(visitor);
            }
            MathNode::Radical { index, radicand } => {
                if let Some(idx) = index {
                    idx.accept(visitor);
                }
                radicand.accept(visitor);
            }
            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                base.accept(visitor);
                if let Some(sub) = subscript {
                    sub.accept(visitor);
                }
                if let Some(sup) = superscript {
                    sup.accept(visitor);
                }
            }
            MathNode::Function { argument, .. } => {
                argument.accept(visitor);
            }
            MathNode::Matrix { rows, .. } => {
                for row in rows {
                    for elem in row {
                        elem.accept(visitor);
                    }
                }
            }
            MathNode::Group { content, .. } => {
                content.accept(visitor);
            }
            MathNode::LargeOp {
                lower,
                upper,
                content,
                ..
            } => {
                if let Some(l) = lower {
                    l.accept(visitor);
                }
                if let Some(u) = upper {
                    u.accept(visitor);
                }
                content.accept(visitor);
            }
            MathNode::Sequence { elements } => {
                for elem in elements {
                    elem.accept(visitor);
                }
            }
            // Leaf variants are listed explicitly (instead of `_`) so that
            // adding a new composite variant fails to compile until this
            // traversal is updated for it.
            MathNode::Symbol { .. }
            | MathNode::Number { .. }
            | MathNode::Text { .. }
            | MathNode::Empty => {}
        }
    }
}
/// Binary operators
///
/// Precedence and associativity are defined on the `impl` below; relational
/// operators all share the lowest precedence tier.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
    Add,
    Subtract,
    Multiply,
    Divide,
    Power,
    Equal,
    NotEqual,
    Less,
    Greater,
    LessEqual,
    GreaterEqual,
    ApproxEqual,
    Equivalent,
    Similar,
    Congruent,
    Proportional,
    /// Custom operator with LaTeX representation
    Custom(String),
}
impl BinaryOp {
/// Get precedence level (higher = binds tighter)
pub fn precedence(&self) -> u8 {
match self {
BinaryOp::Power => 60,
BinaryOp::Multiply | BinaryOp::Divide => 50,
BinaryOp::Add | BinaryOp::Subtract => 40,
BinaryOp::Equal
| BinaryOp::NotEqual
| BinaryOp::Less
| BinaryOp::Greater
| BinaryOp::LessEqual
| BinaryOp::GreaterEqual
| BinaryOp::ApproxEqual
| BinaryOp::Equivalent
| BinaryOp::Similar
| BinaryOp::Congruent
| BinaryOp::Proportional => 30,
BinaryOp::Custom(_) => 35,
}
}
/// Check if operator is left-associative
pub fn is_left_associative(&self) -> bool {
!matches!(self, BinaryOp::Power)
}
}
impl fmt::Display for BinaryOp {
    /// Render the operator with its Unicode glyph (custom operators print
    /// their own text). Restores glyphs that had been lost to mis-encoding
    /// (empty string literals) for the relational operators.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            BinaryOp::Add => write!(f, "+"),
            BinaryOp::Subtract => write!(f, "-"),
            BinaryOp::Multiply => write!(f, "×"),
            BinaryOp::Divide => write!(f, "÷"),
            BinaryOp::Power => write!(f, "^"),
            BinaryOp::Equal => write!(f, "="),
            BinaryOp::NotEqual => write!(f, "≠"),
            BinaryOp::Less => write!(f, "<"),
            BinaryOp::Greater => write!(f, ">"),
            BinaryOp::LessEqual => write!(f, "≤"),
            BinaryOp::GreaterEqual => write!(f, "≥"),
            BinaryOp::ApproxEqual => write!(f, "≈"),
            BinaryOp::Equivalent => write!(f, "≡"),
            BinaryOp::Similar => write!(f, "∼"),
            BinaryOp::Congruent => write!(f, "≅"),
            BinaryOp::Proportional => write!(f, "∝"),
            BinaryOp::Custom(s) => write!(f, "{}", s),
        }
    }
}
/// Unary operators
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Unary plus (+x)
    Plus,
    /// Negation (-x)
    Minus,
    /// Logical not (¬x)
    Not,
    /// Custom unary operator
    Custom(String),
}
impl fmt::Display for UnaryOp {
    /// Render the operator glyph; custom operators print their own text.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let glyph = match self {
            UnaryOp::Plus => "+",
            UnaryOp::Minus => "-",
            UnaryOp::Not => "¬",
            UnaryOp::Custom(s) => s.as_str(),
        };
        f.write_str(glyph)
    }
}
/// Large operator types (∑, ∫, etc.)
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LargeOpType {
    Sum,             // ∑
    Product,         // ∏
    Integral,        // ∫
    DoubleIntegral,  // ∬
    TripleIntegral,  // ∭
    ContourIntegral, // ∮
    Union,           // ⋃
    Intersection,    // ⋂
    Coproduct,       // ∐
    DirectSum,       // ⊕
    /// Custom large operator with its own textual representation
    Custom(String),
}
impl fmt::Display for LargeOpType {
    /// Render the operator's Unicode glyph (matching the per-variant comments
    /// on the enum). Restores glyphs that had been lost to mis-encoding
    /// (empty string literals).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            LargeOpType::Sum => write!(f, "∑"),
            LargeOpType::Product => write!(f, "∏"),
            LargeOpType::Integral => write!(f, "∫"),
            LargeOpType::DoubleIntegral => write!(f, "∬"),
            LargeOpType::TripleIntegral => write!(f, "∭"),
            LargeOpType::ContourIntegral => write!(f, "∮"),
            LargeOpType::Union => write!(f, "⋃"),
            LargeOpType::Intersection => write!(f, "⋂"),
            LargeOpType::Coproduct => write!(f, "∐"),
            LargeOpType::DirectSum => write!(f, "⊕"),
            LargeOpType::Custom(s) => write!(f, "{}", s),
        }
    }
}
/// Bracket types for grouping and matrices
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BracketType {
    Parentheses,    // ( )
    Brackets,       // [ ]
    Braces,         // { }
    AngleBrackets,  // ⟨ ⟩
    Vertical,       // | |
    DoubleVertical, // ‖ ‖
    Floor,          // ⌊ ⌋
    Ceiling,        // ⌈ ⌉
    None,           // No brackets
}
impl BracketType {
    /// Get opening delimiter
    pub fn opening(&self) -> &str {
        self.delimiters().0
    }

    /// Get closing delimiter
    pub fn closing(&self) -> &str {
        self.delimiters().1
    }

    /// Single source of truth for the (open, close) delimiter pair, so the
    /// two public accessors cannot drift apart. Glyphs match the per-variant
    /// comments on the enum (restored where mis-encoding had emptied them).
    fn delimiters(&self) -> (&'static str, &'static str) {
        match self {
            BracketType::Parentheses => ("(", ")"),
            BracketType::Brackets => ("[", "]"),
            BracketType::Braces => ("{", "}"),
            BracketType::AngleBrackets => ("⟨", "⟩"),
            BracketType::Vertical => ("|", "|"),
            BracketType::DoubleVertical => ("‖", "‖"),
            BracketType::Floor => ("⌊", "⌋"),
            BracketType::Ceiling => ("⌈", "⌉"),
            BracketType::None => ("", ""),
        }
    }
}
/// Visitor pattern for traversing the AST
///
/// Implementations receive every node once, in the pre-order defined by
/// [`MathNode::accept`].
pub trait MathVisitor {
    /// Called for each node during traversal.
    fn visit(&mut self, node: &MathNode);
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_binary_op_precedence() {
        assert!(BinaryOp::Power.precedence() > BinaryOp::Multiply.precedence());
        assert!(BinaryOp::Multiply.precedence() > BinaryOp::Add.precedence());
        assert!(BinaryOp::Add.precedence() > BinaryOp::Equal.precedence());
    }

    #[test]
    fn test_binary_op_associativity() {
        assert!(BinaryOp::Add.is_left_associative());
        assert!(BinaryOp::Multiply.is_left_associative());
        assert!(!BinaryOp::Power.is_left_associative());
    }

    #[test]
    fn test_bracket_delimiters() {
        // Check both sides of each common bracket kind (the closing side of
        // Brackets and the None case were previously untested).
        assert_eq!(BracketType::Parentheses.opening(), "(");
        assert_eq!(BracketType::Parentheses.closing(), ")");
        assert_eq!(BracketType::Brackets.opening(), "[");
        assert_eq!(BracketType::Brackets.closing(), "]");
        assert_eq!(BracketType::Braces.opening(), "{");
        assert_eq!(BracketType::Braces.closing(), "}");
        assert_eq!(BracketType::None.opening(), "");
        assert_eq!(BracketType::None.closing(), "");
    }

    #[test]
    fn test_math_expr_creation() {
        let expr = MathExpr::new(
            MathNode::Number {
                value: "42".to_string(),
                is_decimal: false,
            },
            0.95,
        );
        assert_eq!(expr.confidence, 0.95);
    }

    #[test]
    fn test_visitor_pattern() {
        struct CountVisitor {
            count: usize,
        }
        impl MathVisitor for CountVisitor {
            fn visit(&mut self, _node: &MathNode) {
                self.count += 1;
            }
        }
        let expr = MathExpr::new(
            MathNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(MathNode::Number {
                    value: "1".to_string(),
                    is_decimal: false,
                }),
                right: Box::new(MathNode::Number {
                    value: "2".to_string(),
                    is_decimal: false,
                }),
            },
            1.0,
        );
        let mut visitor = CountVisitor { count: 0 };
        expr.accept(&mut visitor);
        assert_eq!(visitor.count, 3); // Binary + 2 numbers
    }
}

View File

@@ -0,0 +1,608 @@
//! LaTeX generation from mathematical AST
//!
//! This module converts mathematical AST nodes to LaTeX strings with proper
//! formatting, precedence handling, and delimiter placement.
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
use crate::math::symbols::unicode_to_latex;
/// Configuration for LaTeX generation
///
/// Controls stylistic choices only; the generated structure is determined by
/// the AST.
#[derive(Debug, Clone)]
pub struct LaTeXConfig {
    /// Use display style (true) or inline style (false)
    pub display_style: bool,
    /// Use \left and \right for delimiters
    pub auto_size_delimiters: bool,
    /// Insert spaces around operators
    pub spacing: bool,
}
impl Default for LaTeXConfig {
fn default() -> Self {
Self {
display_style: false,
auto_size_delimiters: true,
spacing: true,
}
}
}
/// LaTeX generator for mathematical expressions
///
/// Stateless apart from its [`LaTeXConfig`]; safe to reuse across expressions.
pub struct LaTeXGenerator {
    // Styling options applied during generation.
    config: LaTeXConfig,
}
impl LaTeXGenerator {
/// Create a new LaTeX generator with default configuration
pub fn new() -> Self {
Self {
config: LaTeXConfig::default(),
}
}
/// Create a new LaTeX generator with custom configuration
pub fn with_config(config: LaTeXConfig) -> Self {
Self { config }
}
/// Generate LaTeX string from a mathematical expression
pub fn generate(&self, expr: &MathExpr) -> String {
self.generate_node(&expr.root, None)
}
/// Generate LaTeX for a single node
fn generate_node(&self, node: &MathNode, parent_precedence: Option<u8>) -> String {
match node {
MathNode::Symbol { value, unicode } => {
if let Some(c) = unicode {
if let Some(latex) = unicode_to_latex(*c) {
return format!("\\{}", latex);
}
}
value.clone()
}
MathNode::Number { value, .. } => value.clone(),
MathNode::Binary { op, left, right } => {
let precedence = op.precedence();
let needs_parens = parent_precedence.map_or(false, |p| precedence < p);
let left_str = self.generate_node(left, Some(precedence));
let right_str = self.generate_node(
right,
Some(if op.is_left_associative() {
precedence
} else {
precedence + 1
}),
);
let op_str = self.binary_op_to_latex(op);
let space = if self.config.spacing { " " } else { "" };
let result = format!("{}{}{}{}{}", left_str, space, op_str, space, right_str);
if needs_parens {
self.wrap_parens(&result)
} else {
result
}
}
MathNode::Unary { op, operand } => {
let op_str = self.unary_op_to_latex(op);
let operand_str = self.generate_node(operand, Some(70)); // High precedence
format!("{}{}", op_str, operand_str)
}
MathNode::Fraction {
numerator,
denominator,
} => {
let num_str = self.generate_node(numerator, None);
let den_str = self.generate_node(denominator, None);
format!("\\frac{{{}}}{{{}}}", num_str, den_str)
}
MathNode::Radical { index, radicand } => {
let rad_str = self.generate_node(radicand, None);
if let Some(idx) = index {
let idx_str = self.generate_node(idx, None);
format!("\\sqrt[{}]{{{}}}", idx_str, rad_str)
} else {
format!("\\sqrt{{{}}}", rad_str)
}
}
MathNode::Script {
base,
subscript,
superscript,
} => {
let base_str = self.generate_node(base, Some(65));
let mut result = base_str;
if let Some(sub) = subscript {
let sub_str = self.generate_node(sub, None);
result.push_str(&format!("_{{{}}}", sub_str));
}
if let Some(sup) = superscript {
let sup_str = self.generate_node(sup, None);
result.push_str(&format!("^{{{}}}", sup_str));
}
result
}
MathNode::Function { name, argument } => {
let arg_str = self.generate_node(argument, None);
// Check if it's a standard function
if is_standard_function(name) {
format!("\\{} {}", name, arg_str)
} else {
format!("\\text{{{}}}({})", name, arg_str)
}
}
MathNode::Matrix { rows, bracket_type } => {
let env = match bracket_type {
BracketType::Parentheses => "pmatrix",
BracketType::Brackets => "bmatrix",
BracketType::Braces => "Bmatrix",
BracketType::Vertical => "vmatrix",
BracketType::DoubleVertical => "Vmatrix",
_ => "matrix",
};
let mut content = String::new();
for (i, row) in rows.iter().enumerate() {
if i > 0 {
content.push_str(" \\\\ ");
}
for (j, elem) in row.iter().enumerate() {
if j > 0 {
content.push_str(" & ");
}
content.push_str(&self.generate_node(elem, None));
}
}
format!("\\begin{{{}}} {} \\end{{{}}}", env, content, env)
}
MathNode::Group {
content,
bracket_type,
} => {
let content_str = self.generate_node(content, None);
self.wrap_with_brackets(&content_str, *bracket_type)
}
MathNode::LargeOp {
op_type,
lower,
upper,
content,
} => {
let op_str = self.large_op_to_latex(op_type);
let content_str = self.generate_node(content, None);
let mut result = op_str;
if let Some(low) = lower {
let low_str = self.generate_node(low, None);
result.push_str(&format!("_{{{}}}", low_str));
}
if let Some(up) = upper {
let up_str = self.generate_node(up, None);
result.push_str(&format!("^{{{}}}", up_str));
}
format!("{} {}", result, content_str)
}
MathNode::Sequence { elements } => elements
.iter()
.map(|e| self.generate_node(e, None))
.collect::<Vec<_>>()
.join(", "),
MathNode::Text { content } => {
format!("\\text{{{}}}", content)
}
MathNode::Empty => String::new(),
}
}
/// Convert binary operator to LaTeX
fn binary_op_to_latex(&self, op: &BinaryOp) -> String {
match op {
BinaryOp::Add => "+".to_string(),
BinaryOp::Subtract => "-".to_string(),
BinaryOp::Multiply => "\\times".to_string(),
BinaryOp::Divide => "\\div".to_string(),
BinaryOp::Power => "^".to_string(),
BinaryOp::Equal => "=".to_string(),
BinaryOp::NotEqual => "\\neq".to_string(),
BinaryOp::Less => "<".to_string(),
BinaryOp::Greater => ">".to_string(),
BinaryOp::LessEqual => "\\leq".to_string(),
BinaryOp::GreaterEqual => "\\geq".to_string(),
BinaryOp::ApproxEqual => "\\approx".to_string(),
BinaryOp::Equivalent => "\\equiv".to_string(),
BinaryOp::Similar => "\\sim".to_string(),
BinaryOp::Congruent => "\\cong".to_string(),
BinaryOp::Proportional => "\\propto".to_string(),
BinaryOp::Custom(s) => s.to_string(),
}
}
/// Convert unary operator to LaTeX
fn unary_op_to_latex(&self, op: &UnaryOp) -> String {
match op {
UnaryOp::Plus => "+".to_string(),
UnaryOp::Minus => "-".to_string(),
UnaryOp::Not => "\\neg".to_string(),
UnaryOp::Custom(s) => s.to_string(),
}
}
/// Convert large operator to LaTeX
fn large_op_to_latex(&self, op: &LargeOpType) -> String {
match op {
LargeOpType::Sum => "\\sum".to_string(),
LargeOpType::Product => "\\prod".to_string(),
LargeOpType::Integral => "\\int".to_string(),
LargeOpType::DoubleIntegral => "\\iint".to_string(),
LargeOpType::TripleIntegral => "\\iiint".to_string(),
LargeOpType::ContourIntegral => "\\oint".to_string(),
LargeOpType::Union => "\\bigcup".to_string(),
LargeOpType::Intersection => "\\bigcap".to_string(),
LargeOpType::Coproduct => "\\coprod".to_string(),
LargeOpType::DirectSum => "\\bigoplus".to_string(),
LargeOpType::Custom(s) => s.clone(),
}
}
/// Wrap content with brackets
fn wrap_with_brackets(&self, content: &str, bracket_type: BracketType) -> String {
let (left, right) = if self.config.auto_size_delimiters {
match bracket_type {
BracketType::Parentheses => ("\\left(", "\\right)"),
BracketType::Brackets => ("\\left[", "\\right]"),
BracketType::Braces => ("\\left\\{", "\\right\\}"),
BracketType::AngleBrackets => ("\\left\\langle", "\\right\\rangle"),
BracketType::Vertical => ("\\left|", "\\right|"),
BracketType::DoubleVertical => ("\\left\\|", "\\right\\|"),
BracketType::Floor => ("\\left\\lfloor", "\\right\\rfloor"),
BracketType::Ceiling => ("\\left\\lceil", "\\right\\rceil"),
BracketType::None => ("", ""),
}
} else {
match bracket_type {
BracketType::Parentheses => ("(", ")"),
BracketType::Brackets => ("[", "]"),
BracketType::Braces => ("\\{", "\\}"),
BracketType::AngleBrackets => ("\\langle", "\\rangle"),
BracketType::Vertical => ("|", "|"),
BracketType::DoubleVertical => ("\\|", "\\|"),
BracketType::Floor => ("\\lfloor", "\\rfloor"),
BracketType::Ceiling => ("\\lceil", "\\rceil"),
BracketType::None => ("", ""),
}
};
format!("{}{}{}", left, content, right)
}
    /// Wrap content in parentheses
    ///
    /// Convenience shorthand for `wrap_with_brackets` with
    /// `BracketType::Parentheses`; honors this generator's
    /// `auto_size_delimiters` setting.
    fn wrap_parens(&self, content: &str) -> String {
        self.wrap_with_brackets(content, BracketType::Parentheses)
    }
}
impl Default for LaTeXGenerator {
fn default() -> Self {
Self::new()
}
}
/// Check if a function name is a standard LaTeX function
///
/// Standard functions are those LaTeX typesets upright through a
/// dedicated command (e.g. `\sin`, `\lim`); everything else must be
/// rendered as `\text{...}` by the caller.
fn is_standard_function(name: &str) -> bool {
    // Names with a dedicated upright-text LaTeX command.
    const STANDARD_FUNCTIONS: &[&str] = &[
        // trigonometric
        "sin", "cos", "tan", "cot", "sec", "csc",
        // hyperbolic
        "sinh", "cosh", "tanh", "coth",
        // inverse trigonometric
        "arcsin", "arccos", "arctan",
        // logarithms / exponential
        "ln", "log", "exp",
        // limits and extrema
        "lim", "sup", "inf", "max", "min",
        // linear algebra / number theory
        "det", "dim", "ker", "deg", "gcd", "lcm",
    ];
    STANDARD_FUNCTIONS.contains(&name)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integer `Number` node.
    fn num(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    /// Build a `Symbol` node with no Unicode mapping.
    fn sym(v: &str) -> MathNode {
        MathNode::Symbol {
            value: v.to_string(),
            unicode: None,
        }
    }

    /// Render a node through a default-configured generator.
    fn render(node: MathNode) -> String {
        LaTeXGenerator::new().generate(&MathExpr::new(node, 1.0))
    }

    #[test]
    fn test_simple_number() {
        assert_eq!(render(num("42")), "42");
    }

    #[test]
    fn test_simple_binary() {
        let node = MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(num("1")),
            right: Box::new(num("2")),
        };
        assert_eq!(render(node), "1 + 2");
    }

    #[test]
    fn test_fraction() {
        let node = MathNode::Fraction {
            numerator: Box::new(num("1")),
            denominator: Box::new(num("2")),
        };
        assert_eq!(render(node), "\\frac{1}{2}");
    }

    #[test]
    fn test_square_root() {
        let node = MathNode::Radical {
            index: None,
            radicand: Box::new(num("2")),
        };
        assert_eq!(render(node), "\\sqrt{2}");
    }

    #[test]
    fn test_nth_root() {
        let node = MathNode::Radical {
            index: Some(Box::new(num("3"))),
            radicand: Box::new(num("8")),
        };
        assert_eq!(render(node), "\\sqrt[3]{8}");
    }

    #[test]
    fn test_superscript() {
        let node = MathNode::Script {
            base: Box::new(sym("x")),
            subscript: None,
            superscript: Some(Box::new(num("2"))),
        };
        assert_eq!(render(node), "x^{2}");
    }

    #[test]
    fn test_subscript() {
        let node = MathNode::Script {
            base: Box::new(sym("a")),
            subscript: Some(Box::new(num("n"))),
            superscript: None,
        };
        assert_eq!(render(node), "a_{n}");
    }

    #[test]
    fn test_complex_fraction() {
        // (a + b) / (c - d)
        let node = MathNode::Fraction {
            numerator: Box::new(MathNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(sym("a")),
                right: Box::new(sym("b")),
            }),
            denominator: Box::new(MathNode::Binary {
                op: BinaryOp::Subtract,
                left: Box::new(sym("c")),
                right: Box::new(sym("d")),
            }),
        };
        assert_eq!(render(node), "\\frac{a + b}{c - d}");
    }

    #[test]
    fn test_summation() {
        // sum_{i=1}^{n} i
        let node = MathNode::LargeOp {
            op_type: LargeOpType::Sum,
            lower: Some(Box::new(MathNode::Binary {
                op: BinaryOp::Equal,
                left: Box::new(sym("i")),
                right: Box::new(num("1")),
            })),
            upper: Some(Box::new(sym("n"))),
            content: Box::new(sym("i")),
        };
        assert_eq!(render(node), "\\sum_{i = 1}^{n} i");
    }

    #[test]
    fn test_integral() {
        // int x dx — the differential is modeled as a plain symbol inside
        // a Sequence, which the generator joins with ", ".
        let node = MathNode::LargeOp {
            op_type: LargeOpType::Integral,
            lower: None,
            upper: None,
            content: Box::new(MathNode::Sequence {
                elements: vec![sym("x"), sym("dx")],
            }),
        };
        assert_eq!(render(node), "\\int x, dx");
    }

    #[test]
    fn test_matrix() {
        let node = MathNode::Matrix {
            rows: vec![vec![num("1"), num("2")], vec![num("3"), num("4")]],
            bracket_type: BracketType::Brackets,
        };
        assert_eq!(
            render(node),
            "\\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix}"
        );
    }
}

View File

@@ -0,0 +1,408 @@
//! MathML generation from mathematical AST
//!
//! This module converts mathematical AST nodes to MathML (Mathematical Markup Language)
//! XML format for rendering in web browsers and applications.
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
/// MathML generator for mathematical expressions
pub struct MathMLGenerator {
    /// Use presentation MathML (true) or content MathML (false)
    ///
    /// NOTE(review): this flag is stored but never consulted by
    /// `generate_node`, which always emits presentation elements — confirm
    /// whether content MathML output is still planned.
    presentation: bool,
}
impl MathMLGenerator {
    /// Create a new MathML generator (presentation mode)
    pub fn new() -> Self {
        Self { presentation: true }
    }
    /// Create a content MathML generator
    ///
    /// NOTE(review): the flag is stored but `generate_node` currently emits
    /// presentation MathML in both modes — confirm before relying on it.
    pub fn content() -> Self {
        Self {
            presentation: false,
        }
    }
    /// Generate MathML string from a mathematical expression
    ///
    /// Wraps the rendered tree in a namespaced `<math>` root element.
    pub fn generate(&self, expr: &MathExpr) -> String {
        let content = self.generate_node(&expr.root);
        format!(
            r#"<math xmlns="http://www.w3.org/1998/Math/MathML">{}</math>"#,
            content
        )
    }
    /// Generate MathML for a single node (recursive over the AST)
    fn generate_node(&self, node: &MathNode) -> String {
        match node {
            // Identifiers (<mi>) and numbers (<mn>) carry escaped text.
            MathNode::Symbol { value, .. } => {
                format!("<mi>{}</mi>", escape_xml(value))
            }
            MathNode::Number { value, .. } => {
                format!("<mn>{}</mn>", escape_xml(value))
            }
            MathNode::Binary { op, left, right } => {
                let left_ml = self.generate_node(left);
                let right_ml = self.generate_node(right);
                let op_ml = self.binary_op_to_mathml(op);
                format!("<mrow>{}<mo>{}</mo>{}</mrow>", left_ml, op_ml, right_ml)
            }
            MathNode::Unary { op, operand } => {
                let op_ml = self.unary_op_to_mathml(op);
                let operand_ml = self.generate_node(operand);
                format!("<mrow><mo>{}</mo>{}</mrow>", op_ml, operand_ml)
            }
            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                let num_ml = self.generate_node(numerator);
                let den_ml = self.generate_node(denominator);
                format!("<mfrac>{}{}</mfrac>", num_ml, den_ml)
            }
            MathNode::Radical { index, radicand } => {
                let rad_ml = self.generate_node(radicand);
                if let Some(idx) = index {
                    // <mroot> takes the radicand first, then the index.
                    let idx_ml = self.generate_node(idx);
                    format!("<mroot>{}{}</mroot>", rad_ml, idx_ml)
                } else {
                    format!("<msqrt>{}</msqrt>", rad_ml)
                }
            }
            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                // Pick msubsup / msub / msup depending on which scripts exist.
                let base_ml = self.generate_node(base);
                match (subscript, superscript) {
                    (Some(sub), Some(sup)) => {
                        let sub_ml = self.generate_node(sub);
                        let sup_ml = self.generate_node(sup);
                        format!("<msubsup>{}{}{}</msubsup>", base_ml, sub_ml, sup_ml)
                    }
                    (Some(sub), None) => {
                        let sub_ml = self.generate_node(sub);
                        format!("<msub>{}{}</msub>", base_ml, sub_ml)
                    }
                    (None, Some(sup)) => {
                        let sup_ml = self.generate_node(sup);
                        format!("<msup>{}{}</msup>", base_ml, sup_ml)
                    }
                    (None, None) => base_ml,
                }
            }
            MathNode::Function { name, argument } => {
                // NOTE(review): &ApplyFunction; is a named entity; it
                // renders only under an entity-aware (HTML/MathML) parser.
                let name_ml = format!("<mi>{}</mi>", escape_xml(name));
                let arg_ml = self.generate_node(argument);
                format!("<mrow>{}<mo>&ApplyFunction;</mo>{}</mrow>", name_ml, arg_ml)
            }
            MathNode::Matrix { rows, bracket_type } => {
                let mut content = String::new();
                for row in rows {
                    content.push_str("<mtr>");
                    for elem in row {
                        content.push_str("<mtd>");
                        content.push_str(&self.generate_node(elem));
                        content.push_str("</mtd>");
                    }
                    content.push_str("</mtr>");
                }
                let (open, close) = self.bracket_to_mathml(*bracket_type);
                format!(
                    "<mrow><mo>{}</mo><mtable>{}</mtable><mo>{}</mo></mrow>",
                    open, content, close
                )
            }
            MathNode::Group {
                content,
                bracket_type,
            } => {
                let content_ml = self.generate_node(content);
                let (open, close) = self.bracket_to_mathml(*bracket_type);
                if *bracket_type == BracketType::None {
                    content_ml
                } else {
                    format!(
                        "<mrow><mo>{}</mo>{}<mo>{}</mo></mrow>",
                        open, content_ml, close
                    )
                }
            }
            MathNode::LargeOp {
                op_type,
                lower,
                upper,
                content,
            } => {
                // munderover / munder / mover place limits around the operator.
                let op_ml = self.large_op_to_mathml(op_type);
                let content_ml = self.generate_node(content);
                match (lower, upper) {
                    (Some(low), Some(up)) => {
                        let low_ml = self.generate_node(low);
                        let up_ml = self.generate_node(up);
                        format!(
                            "<mrow><munderover><mo>{}</mo>{}{}</munderover>{}</mrow>",
                            op_ml, low_ml, up_ml, content_ml
                        )
                    }
                    (Some(low), None) => {
                        let low_ml = self.generate_node(low);
                        format!(
                            "<mrow><munder><mo>{}</mo>{}</munder>{}</mrow>",
                            op_ml, low_ml, content_ml
                        )
                    }
                    (None, Some(up)) => {
                        let up_ml = self.generate_node(up);
                        format!(
                            "<mrow><mover><mo>{}</mo>{}</mover>{}</mrow>",
                            op_ml, up_ml, content_ml
                        )
                    }
                    (None, None) => {
                        format!("<mrow><mo>{}</mo>{}</mrow>", op_ml, content_ml)
                    }
                }
            }
            MathNode::Sequence { elements } => {
                let mut content = String::new();
                for (i, elem) in elements.iter().enumerate() {
                    if i > 0 {
                        content.push_str("<mo>,</mo>");
                    }
                    content.push_str(&self.generate_node(elem));
                }
                format!("<mrow>{}</mrow>", content)
            }
            MathNode::Text { content } => {
                format!("<mtext>{}</mtext>", escape_xml(content))
            }
            MathNode::Empty => String::new(),
        }
    }
    /// Convert binary operator to the text placed inside `<mo>`
    ///
    /// `<` and `>` are returned pre-escaped because the result is embedded
    /// directly in XML. (The Unicode operator characters below were
    /// previously lost, leaving empty `<mo></mo>` elements.)
    fn binary_op_to_mathml(&self, op: &BinaryOp) -> String {
        match op {
            BinaryOp::Add => "+".to_string(),
            BinaryOp::Subtract => "−".to_string(),
            BinaryOp::Multiply => "×".to_string(),
            BinaryOp::Divide => "÷".to_string(),
            BinaryOp::Power => "^".to_string(),
            BinaryOp::Equal => "=".to_string(),
            BinaryOp::NotEqual => "≠".to_string(),
            BinaryOp::Less => "&lt;".to_string(),
            BinaryOp::Greater => "&gt;".to_string(),
            BinaryOp::LessEqual => "≤".to_string(),
            BinaryOp::GreaterEqual => "≥".to_string(),
            BinaryOp::ApproxEqual => "≈".to_string(),
            BinaryOp::Equivalent => "≡".to_string(),
            BinaryOp::Similar => "∼".to_string(),
            BinaryOp::Congruent => "≅".to_string(),
            BinaryOp::Proportional => "∝".to_string(),
            BinaryOp::Custom(s) => s.clone(),
        }
    }
    /// Convert unary operator to MathML operator text
    fn unary_op_to_mathml(&self, op: &UnaryOp) -> String {
        match op {
            UnaryOp::Plus => "+".to_string(),
            UnaryOp::Minus => "−".to_string(),
            UnaryOp::Not => "¬".to_string(),
            UnaryOp::Custom(s) => s.clone(),
        }
    }
    /// Convert large operator to its MathML operator character
    ///
    /// Returns `String` (not `&'static str`) so `Custom` payloads are no
    /// longer silently dropped.
    fn large_op_to_mathml(&self, op: &LargeOpType) -> String {
        match op {
            LargeOpType::Sum => "∑",
            LargeOpType::Product => "∏",
            LargeOpType::Integral => "∫",
            LargeOpType::DoubleIntegral => "∬",
            LargeOpType::TripleIntegral => "∭",
            LargeOpType::ContourIntegral => "∮",
            LargeOpType::Union => "⋃",
            LargeOpType::Intersection => "⋂",
            LargeOpType::Coproduct => "∐",
            LargeOpType::DirectSum => "⊕",
            LargeOpType::Custom(s) => return s.clone(),
        }
        .to_string()
    }
    /// Convert bracket type to a pair of MathML delimiter characters
    fn bracket_to_mathml(&self, bracket_type: BracketType) -> (&'static str, &'static str) {
        match bracket_type {
            BracketType::Parentheses => ("(", ")"),
            BracketType::Brackets => ("[", "]"),
            BracketType::Braces => ("{", "}"),
            BracketType::AngleBrackets => ("⟨", "⟩"),
            BracketType::Vertical => ("|", "|"),
            BracketType::DoubleVertical => ("‖", "‖"),
            BracketType::Floor => ("⌊", "⌋"),
            BracketType::Ceiling => ("⌈", "⌉"),
            BracketType::None => ("", ""),
        }
    }
}
impl Default for MathMLGenerator {
fn default() -> Self {
Self::new()
}
}
/// Escape XML special characters
///
/// Replaces `&`, `<`, `>`, `"`, and `'` with their XML entities so
/// arbitrary text can be embedded in MathML element content. Built in a
/// single pass instead of five chained `replace` calls, avoiding four
/// intermediate allocations.
fn escape_xml(s: &str) -> String {
    // Escaped output is at least as long as the input; reserving the input
    // length avoids most reallocations for typical (mostly-clean) text.
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            _ => out.push(c),
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Render a node through a default (presentation) generator.
    fn render(node: MathNode) -> String {
        MathMLGenerator::new().generate(&MathExpr::new(node, 1.0))
    }

    /// Integer number literal node.
    fn num(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    #[test]
    fn test_number() {
        assert!(render(num("42")).contains("<mn>42</mn>"));
    }

    #[test]
    fn test_symbol() {
        let node = MathNode::Symbol {
            value: "x".to_string(),
            unicode: Some('x'),
        };
        assert!(render(node).contains("<mi>x</mi>"));
    }

    #[test]
    fn test_binary_add() {
        let node = MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(num("1")),
            right: Box::new(num("2")),
        };
        let result = render(node);
        assert!(result.contains("<mrow>"));
        assert!(result.contains("<mo>+</mo>"));
    }

    #[test]
    fn test_fraction() {
        let node = MathNode::Fraction {
            numerator: Box::new(num("1")),
            denominator: Box::new(num("2")),
        };
        assert!(render(node).contains("<mfrac>"));
    }

    #[test]
    fn test_sqrt() {
        let node = MathNode::Radical {
            index: None,
            radicand: Box::new(num("2")),
        };
        assert!(render(node).contains("<msqrt>"));
    }

    #[test]
    fn test_superscript() {
        let node = MathNode::Script {
            base: Box::new(MathNode::Symbol {
                value: "x".to_string(),
                unicode: Some('x'),
            }),
            subscript: None,
            superscript: Some(Box::new(num("2"))),
        };
        assert!(render(node).contains("<msup>"));
    }

    #[test]
    fn test_xml_escaping() {
        assert_eq!(escape_xml("a < b"), "a &lt; b");
        assert_eq!(escape_xml("x & y"), "x &amp; y");
    }
}

View File

@@ -0,0 +1,246 @@
//! Mathematical expression parsing and conversion module
//!
//! This module provides functionality for parsing, representing, and converting
//! mathematical expressions between various formats including LaTeX, MathML, and AsciiMath.
//!
//! # Modules
//!
//! - `ast`: Abstract Syntax Tree definitions for mathematical expressions
//! - `symbols`: Symbol mappings between Unicode and LaTeX
//! - `latex`: LaTeX generation from AST
//! - `mathml`: MathML generation from AST
//! - `asciimath`: AsciiMath generation from AST
//! - `parser`: Expression parsing from various formats
//!
//! # Examples
//!
//! ## Parsing and converting to LaTeX
//!
//! ```no_run
//! use ruvector_scipix::math::{parse_expression, to_latex};
//!
//! let expr = parse_expression("x^2 + 2x + 1").unwrap();
//! let latex = to_latex(&expr);
//! println!("LaTeX: {}", latex);
//! ```
//!
//! ## Building an expression manually
//!
//! ```no_run
//! use ruvector_scipix::math::ast::{MathExpr, MathNode, BinaryOp};
//!
//! let expr = MathExpr::new(
//! MathNode::Binary {
//! op: BinaryOp::Add,
//! left: Box::new(MathNode::Number {
//! value: "1".to_string(),
//! is_decimal: false,
//! }),
//! right: Box::new(MathNode::Number {
//! value: "2".to_string(),
//! is_decimal: false,
//! }),
//! },
//! 1.0,
//! );
//! ```
pub mod asciimath;
pub mod ast;
pub mod latex;
pub mod mathml;
pub mod parser;
pub mod symbols;
// Re-export commonly used types
pub use asciimath::AsciiMathGenerator;
pub use ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, MathVisitor, UnaryOp};
pub use latex::{LaTeXConfig, LaTeXGenerator};
pub use mathml::MathMLGenerator;
pub use parser::{parse_expression, Parser};
pub use symbols::{get_symbol, unicode_to_latex, MathSymbol, SymbolCategory};
/// Parse a mathematical expression from a string
///
/// Thin convenience wrapper around [`parse_expression`].
///
/// # Arguments
///
/// * `input` - The input string to parse (LaTeX, Unicode, or mixed)
///
/// # Returns
///
/// A `Result` containing the parsed `MathExpr` or an error message
///
/// # Examples
///
/// ```no_run
/// use ruvector_scipix::math::parse;
///
/// let expr = parse("\\frac{1}{2}").unwrap();
/// ```
pub fn parse(input: &str) -> Result<MathExpr, String> {
    parse_expression(input)
}
/// Convert a mathematical expression to LaTeX format
///
/// Uses a default-configured [`LaTeXGenerator`]; see
/// [`to_latex_with_config`] to customize the output.
///
/// # Arguments
///
/// * `expr` - The mathematical expression to convert
///
/// # Returns
///
/// A LaTeX string representation of the expression
///
/// # Examples
///
/// ```no_run
/// use ruvector_scipix::math::{parse_expression, to_latex};
///
/// let expr = parse_expression("x^2").unwrap();
/// let latex = to_latex(&expr);
/// assert!(latex.contains("^"));
/// ```
pub fn to_latex(expr: &MathExpr) -> String {
    LaTeXGenerator::new().generate(expr)
}
/// Convert a mathematical expression to LaTeX with custom configuration
///
/// Same as [`to_latex`] but with an explicit [`LaTeXConfig`].
///
/// # Arguments
///
/// * `expr` - The mathematical expression to convert
/// * `config` - LaTeX generation configuration (e.g. delimiter auto-sizing)
///
/// # Returns
///
/// A LaTeX string representation of the expression
pub fn to_latex_with_config(expr: &MathExpr, config: LaTeXConfig) -> String {
    LaTeXGenerator::with_config(config).generate(expr)
}
/// Convert a mathematical expression to MathML format
///
/// The output is a complete, namespaced `<math>` element produced by a
/// default [`MathMLGenerator`].
///
/// # Arguments
///
/// * `expr` - The mathematical expression to convert
///
/// # Returns
///
/// A MathML XML string representation of the expression
///
/// # Examples
///
/// ```no_run
/// use ruvector_scipix::math::{parse_expression, to_mathml};
///
/// let expr = parse_expression("x^2").unwrap();
/// let mathml = to_mathml(&expr);
/// assert!(mathml.contains("<msup>"));
/// ```
pub fn to_mathml(expr: &MathExpr) -> String {
    MathMLGenerator::new().generate(expr)
}
/// Convert a mathematical expression to AsciiMath format
///
/// Uses a default [`AsciiMathGenerator`]; see [`to_asciimath_ascii_only`]
/// for output restricted to ASCII characters.
///
/// # Arguments
///
/// * `expr` - The mathematical expression to convert
///
/// # Returns
///
/// An AsciiMath string representation of the expression
///
/// # Examples
///
/// ```no_run
/// use ruvector_scipix::math::{parse_expression, to_asciimath};
///
/// let expr = parse_expression("x^2").unwrap();
/// let asciimath = to_asciimath(&expr);
/// ```
pub fn to_asciimath(expr: &MathExpr) -> String {
    AsciiMathGenerator::new().generate(expr)
}
/// Convert a mathematical expression to ASCII-only AsciiMath format
///
/// Like [`to_asciimath`], but built with [`AsciiMathGenerator::ascii_only`]
/// so the result avoids non-ASCII characters.
///
/// # Arguments
///
/// * `expr` - The mathematical expression to convert
///
/// # Returns
///
/// An ASCII-only AsciiMath string representation of the expression
pub fn to_asciimath_ascii_only(expr: &MathExpr) -> String {
    AsciiMathGenerator::ascii_only().generate(expr)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_and_convert() {
        let latex = to_latex(&parse("1 + 2").unwrap());
        assert!(latex.contains("+"));
    }

    #[test]
    fn test_fraction_conversion() {
        let expr = parse("\\frac{1}{2}").unwrap();
        assert!(to_latex(&expr).contains("\\frac"));
        assert!(to_mathml(&expr).contains("<mfrac>"));
        assert!(to_asciimath(&expr).contains("/"));
    }

    #[test]
    fn test_sqrt_conversion() {
        let expr = parse("\\sqrt{2}").unwrap();
        assert!(to_latex(&expr).contains("\\sqrt"));
        assert!(to_mathml(&expr).contains("<msqrt>"));
        assert!(to_asciimath(&expr).contains("sqrt"));
    }

    #[test]
    fn test_complex_expression() {
        // Quadratic-formula-shaped input: (-b + sqrt(b^2 - 4ac)) / 2a
        let expr = parse("\\frac{-b + \\sqrt{b^2 - 4*a*c}}{2*a}").unwrap();
        let latex = to_latex(&expr);
        assert!(latex.contains("\\frac"));
        assert!(latex.contains("\\sqrt"));
        let mathml = to_mathml(&expr);
        assert!(mathml.contains("<mfrac>"));
        assert!(mathml.contains("<msqrt>"));
    }

    #[test]
    fn test_symbol_lookup() {
        assert_eq!(unicode_to_latex('α'), Some("alpha"));
        assert_eq!(unicode_to_latex('π'), Some("pi"));
        assert_eq!(unicode_to_latex('∑'), Some("sum"));
    }

    #[test]
    fn test_get_symbol() {
        let sym = get_symbol('α').expect("alpha should be a known symbol");
        assert_eq!(sym.latex, "alpha");
        assert_eq!(sym.category, SymbolCategory::Greek);
    }
}

View File

@@ -0,0 +1,529 @@
//! Mathematical expression parser
//!
//! This module parses mathematical expressions from various formats
//! including LaTeX, Unicode text, and symbolic notation.
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
use crate::math::symbols::get_symbol;
use nom::{
branch::alt,
bytes::complete::{tag, take_while, take_while1},
character::complete::{alpha1, char, digit1, multispace0},
combinator::{map, opt, recognize},
multi::{many0, separated_list0},
sequence::{delimited, pair, preceded, tuple},
IResult,
};
/// Parser for mathematical expressions
pub struct Parser {
    /// Confidence score for parsed expression
    ///
    /// NOTE(review): initialized to 1.0 and never lowered anywhere in the
    /// visible parsing code, so every successful parse currently reports
    /// full confidence.
    confidence: f32,
}
impl Parser {
    /// Create a new parser
    pub fn new() -> Self {
        Self { confidence: 1.0 }
    }
    /// Parse a mathematical expression from string
    ///
    /// Returns the parsed AST wrapped in a [`MathExpr`] carrying this
    /// parser's confidence score, or a formatted error message.
    ///
    /// NOTE(review): any trailing unparsed input is silently discarded
    /// (the remaining slice from nom is ignored) — confirm whether a
    /// full-consumption check is wanted.
    pub fn parse(&mut self, input: &str) -> Result<MathExpr, String> {
        match self.parse_expression(input) {
            Ok((_, node)) => Ok(MathExpr::new(node, self.confidence)),
            Err(e) => Err(format!("Parse error: {:?}", e)),
        }
    }
    /// Parse top-level expression (lowest precedence: relational)
    fn parse_expression<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        self.parse_relational(input)
    }
    /// Parse relational operators (=, <, >, etc.)
    ///
    /// Relations are non-associative: at most one relational operator is
    /// consumed, joining two additive expressions.
    ///
    /// Bug fix: the Unicode relation tags (≠, ≤, ≥, ≈, ≡, ∼) had been
    /// stripped to `tag("")`, which always matches while consuming nothing,
    /// so e.g. "a <= b" mis-parsed as `a ≠ …` before "<=" was ever tried.
    fn parse_relational<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_additive(input)?;
        let (input, op_right) = opt(pair(
            delimited(
                multispace0,
                // Multi-character tags must precede their prefixes
                // ("==" before "=", "<=" before "<").
                alt((
                    map(tag("=="), |_| BinaryOp::Equal),
                    map(tag("="), |_| BinaryOp::Equal),
                    map(tag("!="), |_| BinaryOp::NotEqual),
                    map(tag("≠"), |_| BinaryOp::NotEqual),
                    map(tag("<="), |_| BinaryOp::LessEqual),
                    map(tag("≤"), |_| BinaryOp::LessEqual),
                    map(tag(">="), |_| BinaryOp::GreaterEqual),
                    map(tag("≥"), |_| BinaryOp::GreaterEqual),
                    map(tag("<"), |_| BinaryOp::Less),
                    map(tag(">"), |_| BinaryOp::Greater),
                    map(tag("≈"), |_| BinaryOp::ApproxEqual),
                    map(tag("≡"), |_| BinaryOp::Equivalent),
                    map(tag("∼"), |_| BinaryOp::Similar),
                )),
                multispace0,
            ),
            |i| self.parse_additive(i),
        ))(input)?;
        Ok((
            input,
            if let Some((op, right)) = op_right {
                MathNode::Binary {
                    op,
                    left: Box::new(left),
                    right: Box::new(right),
                }
            } else {
                left
            },
        ))
    }
    /// Parse additive operators (+, -), left-associative via fold
    fn parse_additive<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_multiplicative(input)?;
        let (input, ops) = many0(pair(
            delimited(
                multispace0,
                alt((
                    map(char('+'), |_| BinaryOp::Add),
                    map(char('-'), |_| BinaryOp::Subtract),
                )),
                multispace0,
            ),
            |i| self.parse_multiplicative(i),
        ))(input)?;
        Ok((
            input,
            ops.into_iter()
                .fold(left, |acc, (op, right)| MathNode::Binary {
                    op,
                    left: Box::new(acc),
                    right: Box::new(right),
                }),
        ))
    }
    /// Parse multiplicative operators (*, /, ×, ÷), left-associative
    fn parse_multiplicative<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_power(input)?;
        let (input, ops) = many0(pair(
            delimited(
                multispace0,
                alt((
                    map(char('*'), |_| BinaryOp::Multiply),
                    map(char('/'), |_| BinaryOp::Divide),
                    map(char('×'), |_| BinaryOp::Multiply),
                    map(char('÷'), |_| BinaryOp::Divide),
                    map(tag("\\times"), |_| BinaryOp::Multiply),
                    map(tag("\\div"), |_| BinaryOp::Divide),
                    map(tag("\\cdot"), |_| BinaryOp::Multiply),
                )),
                multispace0,
            ),
            |i| self.parse_power(i),
        ))(input)?;
        Ok((
            input,
            ops.into_iter()
                .fold(left, |acc, (op, right)| MathNode::Binary {
                    op,
                    left: Box::new(acc),
                    right: Box::new(right),
                }),
        ))
    }
    /// Parse power operator (^)
    ///
    /// NOTE(review): `parse_script` (reached via `parse_unary`) already
    /// consumes `^…` into a `MathNode::Script`, so this branch only fires
    /// for inputs `parse_script` rejects — confirm which node kind `x^2`
    /// is intended to produce.
    fn parse_power<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, base) = self.parse_unary(input)?;
        let (input, exp) = opt(preceded(
            delimited(multispace0, char('^'), multispace0),
            |i| self.parse_unary(i),
        ))(input)?;
        Ok((
            input,
            if let Some(exponent) = exp {
                MathNode::Binary {
                    op: BinaryOp::Power,
                    left: Box::new(base),
                    right: Box::new(exponent),
                }
            } else {
                base
            },
        ))
    }
    /// Parse unary prefix operators (+, -)
    fn parse_unary<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        alt((
            map(
                pair(
                    delimited(
                        multispace0,
                        alt((
                            map(char('+'), |_| UnaryOp::Plus),
                            map(char('-'), |_| UnaryOp::Minus),
                        )),
                        multispace0,
                    ),
                    |i| self.parse_script(i),
                ),
                |(op, operand)| MathNode::Unary {
                    op,
                    operand: Box::new(operand),
                },
            ),
            |i| self.parse_script(i),
        ))(input)
    }
    /// Parse subscript/superscript attached to a primary expression
    fn parse_script<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, base) = self.parse_primary(input)?;
        let (input, sub) = opt(preceded(char('_'), |i| self.parse_script_content(i)))(input)?;
        let (input, sup) = opt(preceded(char('^'), |i| self.parse_script_content(i)))(input)?;
        Ok((
            input,
            if sub.is_some() || sup.is_some() {
                MathNode::Script {
                    base: Box::new(base),
                    subscript: sub.map(Box::new),
                    superscript: sup.map(Box::new),
                }
            } else {
                base
            },
        ))
    }
    /// Parse script content (a braced expression, identifier, or digits)
    fn parse_script_content<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        alt((
            delimited(char('{'), |i| self.parse_expression(i), char('}')),
            map(recognize(alpha1), |s: &str| MathNode::Symbol {
                value: s.to_string(),
                unicode: s.chars().next(),
            }),
            map(digit1, |s: &str| MathNode::Number {
                value: s.to_string(),
                is_decimal: false,
            }),
        ))(input)
    }
    /// Parse primary expressions (atoms), trimming surrounding whitespace
    ///
    /// Alternatives are ordered most-specific first; `parse_function` and
    /// `parse_greek` both start with `\name`, so the function parser (which
    /// additionally requires an argument) is tried first and backtracks to
    /// the Greek-letter parser on bare commands like `\alpha`.
    fn parse_primary<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        delimited(
            multispace0,
            alt((
                |i| self.parse_function(i),
                |i| self.parse_fraction(i),
                |i| self.parse_radical(i),
                |i| self.parse_large_op(i),
                |i| self.parse_greek(i),
                |i| self.parse_number(i),
                |i| self.parse_symbol(i),
                |i| self.parse_grouped(i),
            )),
            multispace0,
        )(input)
    }
    /// Parse fraction (\frac{a}{b})
    fn parse_fraction<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = tag("\\frac")(input)?;
        let (input, num) = delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;
        let (input, den) = delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;
        Ok((
            input,
            MathNode::Fraction {
                numerator: Box::new(num),
                denominator: Box::new(den),
            },
        ))
    }
    /// Parse radical (\sqrt[n]{x}); the index is optional
    fn parse_radical<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = tag("\\sqrt")(input)?;
        let (input, index) = opt(delimited(
            char('['),
            |i| self.parse_expression(i),
            char(']'),
        ))(input)?;
        let (input, radicand) =
            delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;
        Ok((
            input,
            MathNode::Radical {
                index: index.map(Box::new),
                radicand: Box::new(radicand),
            },
        ))
    }
    /// Parse large operators (sum, integral, etc.) with optional limits
    ///
    /// Bug fix: the Unicode tags (∑, ∏, ∫) had been stripped to `tag("")`,
    /// which always matches while consuming nothing; that made this branch
    /// match on arbitrary input and recurse into `parse_primary` on the
    /// same slice — unbounded recursion (stack overflow) for inputs such
    /// as `\alpha`.
    fn parse_large_op<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, op_type) = alt((
            map(tag("\\sum"), |_| LargeOpType::Sum),
            map(tag("\\prod"), |_| LargeOpType::Product),
            map(tag("\\int"), |_| LargeOpType::Integral),
            map(tag("\\iint"), |_| LargeOpType::DoubleIntegral),
            map(tag("\\iiint"), |_| LargeOpType::TripleIntegral),
            map(tag("\\oint"), |_| LargeOpType::ContourIntegral),
            map(tag("∑"), |_| LargeOpType::Sum),
            map(tag("∏"), |_| LargeOpType::Product),
            map(tag("∫"), |_| LargeOpType::Integral),
        ))(input)?;
        let (input, lower) = opt(preceded(
            char('_'),
            alt((
                delimited(char('{'), |i| self.parse_expression(i), char('}')),
                |i| self.parse_primary(i),
            )),
        ))(input)?;
        let (input, upper) = opt(preceded(
            char('^'),
            alt((
                delimited(char('{'), |i| self.parse_expression(i), char('}')),
                |i| self.parse_primary(i),
            )),
        ))(input)?;
        let (input, content) = self.parse_primary(input)?;
        Ok((
            input,
            MathNode::LargeOp {
                op_type,
                lower: lower.map(Box::new),
                upper: upper.map(Box::new),
                content: Box::new(content),
            },
        ))
    }
    /// Parse function application (\sin x, \cos y, ...)
    ///
    /// Fails (and backtracks) when no argument primary follows the command.
    fn parse_function<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = char('\\')(input)?;
        let (input, name) = alpha1(input)?;
        let (input, _) = multispace0(input)?;
        let (input, arg) = self.parse_primary(input)?;
        Ok((
            input,
            MathNode::Function {
                name: name.to_string(),
                argument: Box::new(arg),
            },
        ))
    }
    /// Parse a Greek-letter command (\alpha, \pi, ...) as a bare symbol
    fn parse_greek<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = char('\\')(input)?;
        let (input, name) = alpha1(input)?;
        // Convert LaTeX name to Unicode if possible; unknown commands are
        // still accepted as symbols with no Unicode mapping.
        let unicode = match name {
            "alpha" => Some('α'),
            "beta" => Some('β'),
            "gamma" => Some('γ'),
            "delta" => Some('δ'),
            "epsilon" => Some('ε'),
            "pi" => Some('π'),
            "theta" => Some('θ'),
            "lambda" => Some('λ'),
            "mu" => Some('μ'),
            "sigma" => Some('σ'),
            _ => None,
        };
        Ok((
            input,
            MathNode::Symbol {
                value: name.to_string(),
                unicode,
            },
        ))
    }
    /// Parse an (unsigned) integer or decimal number
    fn parse_number<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, num_str) = recognize(pair(digit1, opt(pair(char('.'), digit1))))(input)?;
        let is_decimal = num_str.contains('.');
        Ok((
            input,
            MathNode::Number {
                value: num_str.to_string(),
                is_decimal,
            },
        ))
    }
    /// Parse a symbol (variable name): a run of alphabetic characters
    fn parse_symbol<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        map(take_while1(|c: char| c.is_alphabetic()), |s: &str| {
            let c = s.chars().next();
            MathNode::Symbol {
                value: s.to_string(),
                unicode: c,
            }
        })(input)
    }
    /// Parse a parenthesized expression
    fn parse_grouped<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        delimited(char('('), |i| self.parse_expression(i), char(')'))(input)
    }
}
impl Default for Parser {
    // Delegates to `Parser::new()` so `Parser::default()` and
    // `Parser::new()` are interchangeable.
    fn default() -> Self {
        Self::new()
    }
}
/// Parse a mathematical expression from string
///
/// Convenience entry point: builds a fresh `Parser` and runs it over the
/// input, returning the parsed expression or a `String` error message.
pub fn parse_expression(input: &str) -> Result<MathExpr, String> {
    let parser = Parser::new();
    parser.parse(input)
}
#[cfg(test)]
mod tests {
    //! Parser unit tests: one test per grammar production, a precedence
    //! check, and one end-to-end expression (the quadratic formula).
    use super::*;
    #[test]
    fn test_parse_number() {
        let expr = parse_expression("42").unwrap();
        match expr.root {
            MathNode::Number { value, .. } => assert_eq!(value, "42"),
            _ => panic!("Expected Number node"),
        }
    }
    #[test]
    fn test_parse_addition() {
        let expr = parse_expression("1 + 2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Add),
            _ => panic!("Expected Binary node"),
        }
    }
    #[test]
    fn test_parse_multiplication() {
        let expr = parse_expression("3 * 4").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Multiply),
            _ => panic!("Expected Binary node"),
        }
    }
    #[test]
    fn test_parse_precedence() {
        let expr = parse_expression("1 + 2 * 3").unwrap();
        // Should parse as 1 + (2 * 3): multiplication binds tighter, so the
        // root is Add with the Multiply subtree on the right.
        match expr.root {
            MathNode::Binary {
                op: BinaryOp::Add,
                left,
                right,
            } => {
                assert!(matches!(*left, MathNode::Number { .. }));
                assert!(matches!(
                    *right,
                    MathNode::Binary {
                        op: BinaryOp::Multiply,
                        ..
                    }
                ));
            }
            _ => panic!("Expected Add with Multiply on right"),
        }
    }
    #[test]
    fn test_parse_power() {
        let expr = parse_expression("x^2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Power),
            _ => panic!("Expected Binary node with power"),
        }
    }
    #[test]
    fn test_parse_fraction() {
        let expr = parse_expression("\\frac{1}{2}").unwrap();
        match expr.root {
            MathNode::Fraction { .. } => {}
            _ => panic!("Expected Fraction node"),
        }
    }
    #[test]
    fn test_parse_sqrt() {
        // Plain \sqrt has no index (i.e. square root, not an nth root).
        let expr = parse_expression("\\sqrt{2}").unwrap();
        match expr.root {
            MathNode::Radical { index, .. } => assert!(index.is_none()),
            _ => panic!("Expected Radical node"),
        }
    }
    #[test]
    fn test_parse_nth_root() {
        // \sqrt[3]{8} carries the optional bracketed index.
        let expr = parse_expression("\\sqrt[3]{8}").unwrap();
        match expr.root {
            MathNode::Radical { index, .. } => assert!(index.is_some()),
            _ => panic!("Expected Radical node with index"),
        }
    }
    #[test]
    fn test_parse_subscript() {
        let expr = parse_expression("a_n").unwrap();
        match expr.root {
            MathNode::Script { subscript, .. } => assert!(subscript.is_some()),
            _ => panic!("Expected Script node"),
        }
    }
    #[test]
    fn test_parse_superscript() {
        // Note: a superscript on a symbol parses as a Power binary op,
        // not a Script node (mirrors test_parse_power).
        let expr = parse_expression("x^2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Power),
            _ => panic!("Expected power operation"),
        }
    }
    #[test]
    fn test_parse_sum() {
        let expr = parse_expression("\\sum_{i=1}^{n} i").unwrap();
        match expr.root {
            MathNode::LargeOp { op_type, .. } => assert_eq!(op_type, LargeOpType::Sum),
            _ => panic!("Expected LargeOp node"),
        }
    }
    #[test]
    fn test_parse_complex() {
        // End-to-end: the quadratic formula exercises fractions, radicals,
        // powers, and implicit multiplication together.
        let expr = parse_expression("\\frac{-b + \\sqrt{b^2 - 4ac}}{2a}").unwrap();
        match expr.root {
            MathNode::Fraction { .. } => {}
            _ => panic!("Expected Fraction node"),
        }
    }
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,384 @@
//! Confidence Scoring Module
//!
//! This module provides confidence scoring and calibration for OCR results.
//! It includes per-character confidence calculation and aggregation methods.
use super::Result;
use std::collections::HashMap;
use tracing::debug;
/// Calculate confidence score for a single character prediction
///
/// The score is the largest softmax probability over the logits. Because
/// softmax is monotone, the largest probability always belongs to the
/// largest logit, and after max-subtraction that probability simplifies to
/// `exp(0) / exp_sum == 1 / exp_sum` — so only two passes over the logits
/// are needed (the original made a redundant third pass to re-take the max).
///
/// # Arguments
/// * `logits` - Raw logits from the model for this character position
///
/// # Returns
/// Confidence score between 0.0 and 1.0 (0.0 for an empty slice)
pub fn calculate_confidence(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    // Subtract the max logit before exponentiating for numerical stability.
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
    // Probability of the winning class: exp(max - max) / exp_sum = 1 / exp_sum.
    (1.0 / exp_sum).clamp(0.0, 1.0)
}
/// Aggregate multiple confidence scores into a single score
///
/// # Arguments
/// * `confidences` - Individual confidence scores (each in [0.0, 1.0])
///
/// # Returns
/// Aggregated confidence score using the geometric mean (more conservative
/// than the arithmetic mean), or 0.0 for an empty slice.
///
/// The mean is computed in log space: a direct product of many sub-1.0
/// `f32` values underflows to 0.0 for long sequences (e.g. `0.5^200`),
/// which would wrongly report zero confidence for long text. Summing
/// `ln(c)` is stable, and `ln(0.0) == -inf` still yields 0.0, matching
/// the product formulation when any individual confidence is zero.
pub fn aggregate_confidence(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    let n = confidences.len() as f32;
    let log_sum: f32 = confidences.iter().map(|&c| c.ln()).sum();
    (log_sum / n).exp().clamp(0.0, 1.0)
}
/// Alternative aggregation using arithmetic mean
///
/// Returns the average of the scores clamped to [0.0, 1.0], or 0.0 when
/// the slice is empty.
pub fn aggregate_confidence_mean(confidences: &[f32]) -> f32 {
    match confidences.len() {
        0 => 0.0,
        n => (confidences.iter().sum::<f32>() / n as f32).clamp(0.0, 1.0),
    }
}
/// Alternative aggregation using minimum (most conservative)
///
/// Returns the smallest score clamped to [0.0, 1.0].
///
/// An empty slice previously returned the fold seed of 1.0 — i.e. full
/// confidence for no data — which was inconsistent with every other
/// aggregator in this module (they all report 0.0 for empty input).
/// Empty input now returns 0.0 for consistency.
pub fn aggregate_confidence_min(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    confidences
        .iter()
        .fold(f32::INFINITY, |a, &b| a.min(b))
        .clamp(0.0, 1.0)
}
/// Alternative aggregation using harmonic mean
///
/// Scores are floored at 0.001 before taking reciprocals so a zero score
/// cannot produce a division by zero. Empty input yields 0.0.
pub fn aggregate_confidence_harmonic(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }
    let n = confidences.len() as f32;
    let reciprocal_sum = confidences
        .iter()
        .fold(0.0f32, |acc, &c| acc + 1.0 / c.max(0.001));
    (n / reciprocal_sum).clamp(0.0, 1.0)
}
/// Confidence calibrator using isotonic regression
///
/// This calibrator learns a mapping from raw confidence scores to calibrated
/// probabilities using historical data.
pub struct ConfidenceCalibrator {
    /// Calibration mapping: raw_score -> calibrated_score
    /// Keys are percent bins (raw score * 100, truncated), so resolution
    /// is 1% — 101 possible bins (0-100 inclusive).
    calibration_map: HashMap<u8, f32>, // Use u8 for binned scores (0-100)
    /// Whether the calibrator has been trained
    /// (untrained calibrators pass raw scores through unchanged).
    is_trained: bool,
}
impl ConfidenceCalibrator {
    /// Create a new, untrained calibrator
    pub fn new() -> Self {
        Self {
            calibration_map: HashMap::new(),
            is_trained: false,
        }
    }
    /// Train the calibrator on labeled data
    ///
    /// # Arguments
    /// * `predictions` - Raw confidence scores from the model
    /// * `ground_truth` - Binary labels (1.0 if correct, 0.0 if incorrect)
    ///
    /// # Errors
    /// Returns `InvalidConfig` when the slices differ in length or are empty.
    pub fn train(&mut self, predictions: &[f32], ground_truth: &[f32]) -> Result<()> {
        debug!(
            "Training confidence calibrator on {} samples",
            predictions.len()
        );
        if predictions.len() != ground_truth.len() {
            return Err(super::OcrError::InvalidConfig(
                "Predictions and ground truth must have same length".to_string(),
            ));
        }
        if predictions.is_empty() {
            return Err(super::OcrError::InvalidConfig(
                "Cannot train on empty data".to_string(),
            ));
        }
        // Bin the scores (0.0-1.0 -> 0-100); truncation (not rounding)
        // assigns e.g. 0.859 to bin 85.
        let mut bins: HashMap<u8, Vec<f32>> = HashMap::new();
        for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
            let bin = (pred * 100.0).clamp(0.0, 100.0) as u8;
            bins.entry(bin).or_insert_with(Vec::new).push(truth);
        }
        // Calculate mean accuracy for each bin; this becomes the
        // calibrated score for that bin.
        self.calibration_map.clear();
        for (bin, truths) in bins {
            let mean_accuracy = truths.iter().sum::<f32>() / truths.len() as f32;
            self.calibration_map.insert(bin, mean_accuracy);
        }
        // Perform isotonic regression (simplified version)
        self.enforce_monotonicity();
        self.is_trained = true;
        debug!(
            "Calibrator trained with {} bins",
            self.calibration_map.len()
        );
        Ok(())
    }
    /// Enforce monotonicity constraint (isotonic regression)
    ///
    /// NOTE(review): this is a running-max pass (each bin's value is raised
    /// to at least its predecessor's), not full pool-adjacent-violators —
    /// the code itself calls it "simple"/"simplified", so this is intended.
    fn enforce_monotonicity(&mut self) {
        let mut sorted_bins: Vec<_> = self.calibration_map.iter().collect();
        sorted_bins.sort_by_key(|(bin, _)| *bin);
        // Simple isotonic regression: ensure calibrated scores are non-decreasing
        let mut adjusted = HashMap::new();
        let mut prev_value = 0.0;
        for (&bin, &value) in sorted_bins {
            let adjusted_value = value.max(prev_value);
            adjusted.insert(bin, adjusted_value);
            prev_value = adjusted_value;
        }
        self.calibration_map = adjusted;
    }
    /// Calibrate a raw confidence score
    ///
    /// Untrained calibrators return the (clamped) raw score unchanged.
    /// Trained calibrators map through the learned bin table, linearly
    /// interpolating for bins that were never observed during training.
    pub fn calibrate(&self, raw_score: f32) -> f32 {
        if !self.is_trained {
            // If not trained, return raw score
            return raw_score.clamp(0.0, 1.0);
        }
        let bin = (raw_score * 100.0).clamp(0.0, 100.0) as u8;
        // Look up calibrated score, or interpolate
        if let Some(&calibrated) = self.calibration_map.get(&bin) {
            return calibrated;
        }
        // Interpolate between nearest bins
        self.interpolate(bin)
    }
    /// Interpolate calibrated score for a bin without direct mapping
    ///
    /// Finds the nearest trained bins below and above `target_bin` and
    /// linearly interpolates; falls back to one-sided extrapolation when
    /// only one neighbor exists, or to `bin / 100` when the map is empty.
    fn interpolate(&self, target_bin: u8) -> f32 {
        let mut lower = None;
        let mut upper = None;
        // `lower` tracks the largest bin below the target; `upper` the
        // smallest bin above it.
        for &bin in self.calibration_map.keys() {
            if bin < target_bin {
                lower = Some(lower.map_or(bin, |l: u8| l.max(bin)));
            } else if bin > target_bin {
                upper = Some(upper.map_or(bin, |u: u8| u.min(bin)));
            }
        }
        match (lower, upper) {
            (Some(l), Some(u)) => {
                let l_val = self.calibration_map[&l];
                let u_val = self.calibration_map[&u];
                // Safe: l < target_bin < u, so u - l > 0.
                let alpha = (target_bin - l) as f32 / (u - l) as f32;
                l_val + alpha * (u_val - l_val)
            }
            (Some(l), None) => self.calibration_map[&l],
            (None, Some(u)) => self.calibration_map[&u],
            (None, None) => target_bin as f32 / 100.0, // Fallback
        }
    }
    /// Check if the calibrator is trained
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }
    /// Reset the calibrator
    /// (clears the learned mapping and returns to pass-through behavior).
    pub fn reset(&mut self) {
        self.calibration_map.clear();
        self.is_trained = false;
    }
}
impl Default for ConfidenceCalibrator {
    // Equivalent to `ConfidenceCalibrator::new()`: empty map, untrained.
    fn default() -> Self {
        Self::new()
    }
}
/// Calculate Expected Calibration Error (ECE)
///
/// Measures the difference between predicted confidence and actual accuracy:
/// predictions are partitioned into `n_bins` equal-width confidence bins,
/// and the per-bin |mean confidence - mean accuracy| gaps are averaged,
/// weighted by bin occupancy. Returns 0.0 for empty or mismatched inputs.
pub fn calculate_ece(predictions: &[f32], ground_truth: &[f32], n_bins: usize) -> f32 {
    if predictions.len() != ground_truth.len() || predictions.is_empty() {
        return 0.0;
    }
    // Partition (prediction, truth) pairs into equal-width bins; scores at
    // exactly 1.0 land in the last bin via the `min`.
    let mut bins: Vec<Vec<(f32, f32)>> = vec![Vec::new(); n_bins];
    for (&pred, &truth) in predictions.iter().zip(ground_truth) {
        let slot = ((pred * n_bins as f32) as usize).min(n_bins - 1);
        bins[slot].push((pred, truth));
    }
    // Occupancy-weighted sum of the confidence/accuracy gaps.
    let total = predictions.len() as f32;
    bins.iter()
        .filter(|bin| !bin.is_empty())
        .map(|bin| {
            let size = bin.len() as f32;
            let avg_conf = bin.iter().map(|(p, _)| p).sum::<f32>() / size;
            let avg_acc = bin.iter().map(|(_, t)| t).sum::<f32>() / size;
            (size / total) * (avg_conf - avg_acc).abs()
        })
        .sum()
}
#[cfg(test)]
mod tests {
    //! Unit tests for confidence scoring, the four aggregation strategies,
    //! the calibrator lifecycle (train / calibrate / reset), and ECE.
    use super::*;
    #[test]
    fn test_calculate_confidence() {
        // The dominant logit (5.0) should yield a high, valid probability.
        let logits = vec![1.0, 5.0, 2.0, 1.0];
        let conf = calculate_confidence(&logits);
        assert!(conf > 0.5);
        assert!(conf <= 1.0);
    }
    #[test]
    fn test_calculate_confidence_empty() {
        let logits: Vec<f32> = vec![];
        let conf = calculate_confidence(&logits);
        assert_eq!(conf, 0.0);
    }
    #[test]
    fn test_aggregate_confidence() {
        let confidences = vec![0.9, 0.8, 0.95, 0.85];
        let agg = aggregate_confidence(&confidences);
        assert!(agg > 0.0 && agg <= 1.0);
        assert!(agg < 0.9); // Geometric mean should be less than max
    }
    #[test]
    fn test_aggregate_confidence_mean() {
        let confidences = vec![0.8, 0.9, 0.7];
        let mean = aggregate_confidence_mean(&confidences);
        assert_eq!(mean, 0.8); // (0.8 + 0.9 + 0.7) / 3
    }
    #[test]
    fn test_aggregate_confidence_min() {
        let confidences = vec![0.9, 0.7, 0.95];
        let min = aggregate_confidence_min(&confidences);
        assert_eq!(min, 0.7);
    }
    #[test]
    fn test_aggregate_confidence_harmonic() {
        // Harmonic mean of equal values is the value itself.
        let confidences = vec![0.5, 0.5];
        let harmonic = aggregate_confidence_harmonic(&confidences);
        assert_eq!(harmonic, 0.5);
    }
    #[test]
    fn test_calibrator_training() {
        let mut calibrator = ConfidenceCalibrator::new();
        assert!(!calibrator.is_trained());
        let predictions = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0, 0.0];
        let result = calibrator.train(&predictions, &ground_truth);
        assert!(result.is_ok());
        assert!(calibrator.is_trained());
    }
    #[test]
    fn test_calibrator_calibrate() {
        let mut calibrator = ConfidenceCalibrator::new();
        // Before training, should return raw score
        assert_eq!(calibrator.calibrate(0.8), 0.8);
        // Train with some data
        let predictions = vec![0.9, 0.9, 0.8, 0.8, 0.7, 0.7];
        let ground_truth = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        // After training, should return calibrated score
        let calibrated = calibrator.calibrate(0.85);
        assert!(calibrated >= 0.0 && calibrated <= 1.0);
    }
    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = ConfidenceCalibrator::new();
        let predictions = vec![0.9, 0.8];
        let ground_truth = vec![1.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        assert!(calibrator.is_trained());
        calibrator.reset();
        assert!(!calibrator.is_trained());
    }
    #[test]
    fn test_calculate_ece() {
        let predictions = vec![0.9, 0.7, 0.6, 0.8];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0];
        let ece = calculate_ece(&predictions, &ground_truth, 3);
        assert!(ece >= 0.0 && ece <= 1.0);
    }
    #[test]
    fn test_calibrator_monotonicity() {
        let mut calibrator = ConfidenceCalibrator::new();
        // Create data that would violate monotonicity
        // (accuracy dips at 0.3 and 0.5 despite rising confidence).
        let predictions = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let ground_truth = vec![0.2, 0.3, 0.2, 0.5, 0.4, 0.7, 0.8, 0.9, 1.0];
        calibrator.train(&predictions, &ground_truth).unwrap();
        // Check monotonicity
        let score1 = calibrator.calibrate(0.3);
        let score2 = calibrator.calibrate(0.5);
        let score3 = calibrator.calibrate(0.7);
        assert!(score2 >= score1, "Calibrated scores should be monotonic");
        assert!(score3 >= score2, "Calibrated scores should be monotonic");
    }
}

View File

@@ -0,0 +1,441 @@
//! Output Decoding Module
//!
//! This module provides various decoding strategies for converting
//! model output logits into text strings.
use super::{OcrError, Result};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::debug;
/// Decoder trait for converting logits to text
pub trait Decoder: Send + Sync {
    /// Decode logits to text
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String>;
    /// Decode with confidence scores per character
    ///
    /// The default implementation decodes and reports uniform confidence
    /// of 1.0 per character. The vector is sized with `chars().count()`
    /// rather than `len()`: `str::len()` counts UTF-8 bytes, which would
    /// overcount for any non-ASCII vocabulary and break the one-score-per-
    /// character contract.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let text = self.decode(logits)?;
        let confidences = vec![1.0; text.chars().count()];
        Ok((text, confidences))
    }
}
/// Vocabulary mapping for character recognition
#[derive(Debug, Clone)]
pub struct Vocabulary {
    /// Index to character mapping
    idx_to_char: HashMap<usize, char>,
    /// Character to index mapping
    char_to_idx: HashMap<char, usize>,
    /// Blank token index for CTC
    blank_idx: usize,
}

impl Vocabulary {
    /// Build a vocabulary from an ordered character list.
    ///
    /// Each character's index is its position in `chars`; `blank_idx` is
    /// the CTC blank token and is typically one past the last character.
    pub fn new(chars: Vec<char>, blank_idx: usize) -> Self {
        let mut idx_to_char = HashMap::with_capacity(chars.len());
        let mut char_to_idx = HashMap::with_capacity(chars.len());
        for (position, &symbol) in chars.iter().enumerate() {
            idx_to_char.insert(position, symbol);
            char_to_idx.insert(symbol, position);
        }
        Self {
            idx_to_char,
            char_to_idx,
            blank_idx,
        }
    }

    /// Get character by index
    pub fn get_char(&self, idx: usize) -> Option<char> {
        self.idx_to_char.get(&idx).copied()
    }

    /// Get index by character
    pub fn get_idx(&self, ch: char) -> Option<usize> {
        self.char_to_idx.get(&ch).copied()
    }

    /// Get blank token index
    pub fn blank_idx(&self) -> usize {
        self.blank_idx
    }

    /// Get vocabulary size (excludes the blank token)
    pub fn size(&self) -> usize {
        self.idx_to_char.len()
    }
}

impl Default for Vocabulary {
    /// Default vocabulary: lowercase letters, digits, space, with the
    /// blank token placed immediately after the last character.
    fn default() -> Self {
        let chars: Vec<char> = ('a'..='z')
            .chain('0'..='9')
            .chain(std::iter::once(' '))
            .collect();
        let blank_idx = chars.len();
        Self::new(chars, blank_idx)
    }
}
/// Greedy decoder - selects the character with highest probability at each step
pub struct GreedyDecoder {
    vocabulary: Arc<Vocabulary>,
}

impl GreedyDecoder {
    /// Create a new greedy decoder
    pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
        Self { vocabulary }
    }

    /// Index of the largest value in `values`, or 0 for an empty slice.
    ///
    /// NaN comparisons fall back to `Ordering::Equal`, and among equal
    /// values the highest index wins (a `max_by` property) — both
    /// identical to the original behavior.
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(best, _)| best)
    }
}
impl Decoder for GreedyDecoder {
    /// Best-path decode with CTC collapse: take the argmax per frame,
    /// drop blanks, and drop characters repeated from the previous frame.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("Greedy decoding {} frames", logits.len());
        let blank = self.vocabulary.blank_idx();
        let mut text = String::new();
        let mut last: Option<usize> = None;
        for frame in logits {
            let best = Self::argmax(frame);
            // Emit only when this is neither the blank token nor an
            // immediate repeat of the previous frame's prediction.
            if best != blank && last != Some(best) {
                if let Some(symbol) = self.vocabulary.get_char(best) {
                    text.push(symbol);
                }
            }
            last = Some(best);
        }
        Ok(text)
    }

    /// Same collapse as [`decode`], additionally recording the winning
    /// frame's top softmax probability for each emitted character.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let blank = self.vocabulary.blank_idx();
        let mut text = String::new();
        let mut scores = Vec::new();
        let mut last: Option<usize> = None;
        for frame in logits {
            let best = Self::argmax(frame);
            if best != blank && last != Some(best) {
                if let Some(symbol) = self.vocabulary.get_char(best) {
                    text.push(symbol);
                    scores.push(softmax_max(frame));
                }
            }
            last = Some(best);
        }
        Ok((text, scores))
    }
}
/// Beam search decoder - maintains top-k hypotheses for better accuracy
pub struct BeamSearchDecoder {
    vocabulary: Arc<Vocabulary>,
    beam_width: usize,
}

impl BeamSearchDecoder {
    /// Create a new beam search decoder.
    ///
    /// A `beam_width` of 0 is silently promoted to 1 so at least one
    /// hypothesis is always retained.
    pub fn new(vocabulary: Arc<Vocabulary>, beam_width: usize) -> Self {
        let beam_width = beam_width.max(1);
        Self {
            vocabulary,
            beam_width,
        }
    }

    /// Get beam width
    pub fn beam_width(&self) -> usize {
        self.beam_width
    }
}
impl Decoder for BeamSearchDecoder {
    /// Beam-search decode over per-frame logits.
    ///
    /// Each beam is `(text, score, last_idx)`; scores accumulate RAW logit
    /// values (not normalized log-probabilities), so rankings are only
    /// comparable within a run. NOTE(review): hypotheses with identical
    /// text are NOT merged (a simplification vs. standard CTC beam
    /// search), and the `partial_cmp(..).unwrap()` calls will panic on NaN
    /// logits — confirm upstream guarantees finite values.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!(
            "Beam search decoding {} frames (beam_width: {})",
            logits.len(),
            self.beam_width
        );
        if logits.is_empty() {
            return Ok(String::new());
        }
        // Initialize beams: (text, score, last_idx)
        let mut beams: Vec<(String, f32, Option<usize>)> = vec![(String::new(), 0.0, None)];
        for frame_logits in logits {
            let mut new_beams = Vec::new();
            for (text, score, last_idx) in &beams {
                // Get top-k predictions for this frame
                let mut indexed_logits: Vec<(usize, f32)> = frame_logits
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| (i, v))
                    .collect();
                // Sort descending by logit so `take(beam_width)` keeps top-k.
                indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
                // Expand each beam with top-k predictions
                for (idx, logit) in indexed_logits.iter().take(self.beam_width) {
                    let new_score = score + logit;
                    // Skip blank tokens
                    // (text unchanged; last_idx becomes blank so a repeat
                    // after a blank can still be emitted — CTC semantics).
                    if *idx == self.vocabulary.blank_idx() {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }
                    // Skip repeated characters (CTC collapse)
                    if Some(*idx) == *last_idx {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }
                    // Add character to beam
                    if let Some(ch) = self.vocabulary.get_char(*idx) {
                        let mut new_text = text.clone();
                        new_text.push(ch);
                        new_beams.push((new_text, new_score, Some(*idx)));
                    }
                }
            }
            // Keep top beam_width beams
            new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            new_beams.truncate(self.beam_width);
            beams = new_beams;
        }
        // Return the best beam
        Ok(beams
            .first()
            .map(|(text, _, _)| text.clone())
            .unwrap_or_default())
    }
}
/// CTC (Connectionist Temporal Classification) decoder
pub struct CTCDecoder {
    vocabulary: Arc<Vocabulary>,
}

impl CTCDecoder {
    /// Create a new CTC decoder
    pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
        Self { vocabulary }
    }

    /// Collapse repeated characters and remove blanks.
    ///
    /// Standard CTC collapse: blanks are dropped, and a non-blank index is
    /// kept only when it differs from the immediately preceding index (a
    /// blank in between "resets" the repeat, allowing true doubles).
    fn collapse_repeats(&self, indices: &[usize]) -> Vec<usize> {
        let blank = self.vocabulary.blank_idx();
        let mut collapsed = Vec::new();
        let mut last: Option<usize> = None;
        for &idx in indices {
            if idx != blank && last != Some(idx) {
                collapsed.push(idx);
            }
            // Every index (including blanks) becomes the new predecessor.
            last = Some(idx);
        }
        collapsed
    }
}
impl Decoder for CTCDecoder {
    /// Best-path CTC decode: per-frame argmax, then collapse repeats and
    /// strip blanks, then map surviving indices to characters.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("CTC decoding {} frames", logits.len());
        // Get best path (greedy)
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();
        // Collapse repeats and remove blanks
        let collapsed = self.collapse_repeats(&indices);
        // Convert to text
        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();
        Ok(text)
    }
    /// Decode and attach, for each emitted character, the top softmax
    /// probability of the frame that first produced it.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();
        // One confidence per frame (same length as `indices`).
        let confidences: Vec<f32> = logits.iter().map(|frame| softmax_max(frame)).collect();
        let collapsed = self.collapse_repeats(&indices);
        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();
        // Map confidences to non-collapsed positions.
        // `confidence_idx` tracks the current FRAME position, so each
        // emission condition (non-blank, not a repeat — mirroring
        // collapse_repeats) records that frame's confidence.
        // NOTE(review): if `get_char` fails for some index, `text` ends up
        // shorter than `result_confidences` — confirm vocabulary covers
        // every non-blank model index.
        let mut result_confidences = Vec::new();
        let mut prev_idx = None;
        let mut confidence_idx = 0;
        for &idx in &indices {
            if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
                // Guard is always true (one confidence per frame), kept
                // as a defensive bound check.
                if confidence_idx < confidences.len() {
                    result_confidences.push(confidences[confidence_idx]);
                }
            }
            confidence_idx += 1;
            prev_idx = Some(idx);
        }
        Ok((text, result_confidences))
    }
}
/// Calculate softmax and return max probability
///
/// The largest softmax probability always belongs to the largest logit,
/// and after max-subtraction it is `exp(max - max) / exp_sum == 1 / exp_sum`.
/// The original implementation re-folded the slice to recompute the max,
/// producing a guaranteed `exp(0) == 1.0` numerator — that redundant third
/// pass is removed here; the returned value is unchanged.
fn softmax_max(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }
    // Subtract the max before exponentiating for numerical stability.
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
    1.0 / exp_sum
}
#[cfg(test)]
mod tests {
    //! Unit tests for the vocabulary and the three decoding strategies.
    use super::*;
    fn create_test_vocabulary() -> Arc<Vocabulary> {
        Arc::new(Vocabulary::default())
    }
    #[test]
    fn test_vocabulary_default() {
        let vocab = Vocabulary::default();
        assert!(vocab.size() > 0);
        assert_eq!(vocab.get_char(0), Some('a'));
        assert_eq!(vocab.get_idx('a'), Some(0));
    }
    #[test]
    fn test_greedy_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab.clone());
        // Mock logits for "hi": frame 0 -> 'h', frame 1 -> blank, frame 2 -> 'i'.
        let h_idx = vocab.get_idx('h').unwrap();
        let i_idx = vocab.get_idx('i').unwrap();
        let blank = vocab.blank_idx();
        let mut logits = vec![
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
        ];
        logits[0][h_idx] = 10.0;
        logits[1][blank] = 10.0;
        logits[2][i_idx] = 10.0;
        let result = decoder.decode(&logits).unwrap();
        assert_eq!(result, "hi");
    }
    #[test]
    fn test_beam_search_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = BeamSearchDecoder::new(vocab.clone(), 3);
        assert_eq!(decoder.beam_width(), 3);
        let logits = vec![vec![0.0; vocab.size() + 1]; 5];
        let result = decoder.decode(&logits);
        assert!(result.is_ok());
    }
    #[test]
    fn test_ctc_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = CTCDecoder::new(vocab.clone());
        // Test collapse repeats: runs collapse to single indices and the
        // blank separator is dropped.
        let a_idx = vocab.get_idx('a').unwrap();
        let b_idx = vocab.get_idx('b').unwrap();
        let blank = vocab.blank_idx();
        let indices = vec![a_idx, a_idx, blank, b_idx, b_idx, b_idx];
        let collapsed = decoder.collapse_repeats(&indices);
        assert_eq!(collapsed, vec![a_idx, b_idx]);
    }
    #[test]
    fn test_softmax_max() {
        let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let max_prob = softmax_max(&logits);
        assert!(max_prob > 0.0 && max_prob <= 1.0);
        // True value: 1 / (e^-2 + e^-1 + 1 + e^-1 + e^-2) ≈ 0.4984.
        // The previous `> 0.5` assertion was mathematically unsatisfiable
        // for these logits; pin the exact expected probability instead.
        assert!((max_prob - 0.4984).abs() < 1e-3);
    }
    #[test]
    fn test_empty_logits() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab);
        let empty_logits: Vec<Vec<f32>> = vec![];
        let result = decoder.decode(&empty_logits).unwrap();
        assert_eq!(result, "");
    }
}

View File

@@ -0,0 +1,363 @@
//! OCR Engine Implementation
//!
//! This module provides the main OcrEngine for orchestrating OCR operations.
//! It handles model loading, inference coordination, and result assembly.
use super::{
confidence::aggregate_confidence,
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
inference::{DetectionResult, InferenceEngine, RecognitionResult},
models::{ModelHandle, ModelRegistry},
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
};
use parking_lot::RwLock;
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info, warn};
/// OCR processor trait for custom implementations
///
/// Synchronous facade over OCR. NOTE(review): `OcrEngine` implements this
/// by calling `futures::executor::block_on` on its async methods, so these
/// should not be invoked from inside an async runtime thread.
pub trait OcrProcessor: Send + Sync {
    /// Process an image and return OCR results
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;
    /// Batch process multiple images
    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
/// Main OCR Engine with thread-safe model management
///
/// All shared state is behind `Arc` (with `parking_lot::RwLock` where
/// mutable), so the engine itself is cheap to clone/share across tasks.
pub struct OcrEngine {
    /// Model registry for loading and caching models
    registry: Arc<RwLock<ModelRegistry>>,
    /// Inference engine for running ONNX models
    inference: Arc<InferenceEngine>,
    /// Default OCR options
    /// (used by `recognize`; `recognize_with_options` overrides per call)
    default_options: OcrOptions,
    /// Vocabulary for decoding
    vocabulary: Arc<Vocabulary>,
    /// Whether the engine is warmed up
    /// (set once by `warmup()` after a dummy inference)
    warmed_up: Arc<RwLock<bool>>,
}
impl OcrEngine {
    /// Create a new OCR engine with default models
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use ruvector_scipix::ocr::OcrEngine;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let engine = OcrEngine::new().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new() -> Result<Self> {
        Self::with_options(OcrOptions::default()).await
    }
    /// Create a new OCR engine with custom options
    ///
    /// Loads the detection and recognition models, optionally the math
    /// model (when `options.enable_math`), then builds the inference
    /// engine and default vocabulary.
    ///
    /// # Errors
    /// Returns `OcrError::ModelLoading` when any model fails to load.
    ///
    /// NOTE(review): each `registry.write().load_*().await` holds the
    /// write guard across an await point for the duration of the load —
    /// confirm no other task contends for the registry during startup.
    pub async fn with_options(options: OcrOptions) -> Result<Self> {
        info!("Initializing OCR engine with options: {:?}", options);
        // Initialize model registry
        let registry = Arc::new(RwLock::new(ModelRegistry::new()));
        // Load default models (in production, these would be downloaded/cached)
        debug!("Loading detection model...");
        let detection_model = registry.write().load_detection_model().await.map_err(|e| {
            OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
        })?;
        debug!("Loading recognition model...");
        let recognition_model = registry
            .write()
            .load_recognition_model()
            .await
            .map_err(|e| {
                OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
            })?;
        // Math model is optional and gated on the enable_math option.
        let math_model =
            if options.enable_math {
                debug!("Loading math recognition model...");
                Some(registry.write().load_math_model().await.map_err(|e| {
                    OcrError::ModelLoading(format!("Failed to load math model: {}", e))
                })?)
            } else {
                None
            };
        // Create inference engine
        let inference = Arc::new(InferenceEngine::new(
            detection_model,
            recognition_model,
            math_model,
            options.use_gpu,
        )?);
        // Load vocabulary
        let vocabulary = Arc::new(Vocabulary::default());
        let engine = Self {
            registry,
            inference,
            default_options: options,
            vocabulary,
            warmed_up: Arc::new(RwLock::new(false)),
        };
        info!("OCR engine initialized successfully");
        Ok(engine)
    }
    /// Warm up the engine by running a dummy inference
    ///
    /// This helps reduce latency for the first real inference by initializing
    /// all ONNX runtime resources. Idempotent: subsequent calls are no-ops.
    pub async fn warmup(&self) -> Result<()> {
        if *self.warmed_up.read() {
            debug!("Engine already warmed up, skipping");
            return Ok(());
        }
        info!("Warming up OCR engine...");
        let start = Instant::now();
        // Create a small dummy image (100x100 black image)
        // — raw RGB bytes, 3 bytes per pixel.
        let dummy_image = vec![0u8; 100 * 100 * 3];
        // Run a dummy inference; the result (including any error) is
        // deliberately discarded — only resource initialization matters.
        let _ = self.recognize(&dummy_image).await;
        *self.warmed_up.write() = true;
        info!("Engine warmup completed in {:?}", start.elapsed());
        Ok(())
    }
    /// Recognize text in an image using default options
    pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
        self.recognize_with_options(image_data, &self.default_options)
            .await
    }
    /// Recognize text in an image with custom options
    ///
    /// Pipeline: detection -> per-region (text or math) recognition ->
    /// decoding -> confidence filtering -> result assembly. Regions whose
    /// aggregate confidence falls below `options.recognition_threshold`
    /// are dropped from the output.
    pub async fn recognize_with_options(
        &self,
        image_data: &[u8],
        options: &OcrOptions,
    ) -> Result<OcrResult> {
        let start = Instant::now();
        debug!("Starting OCR recognition");
        // Step 1: Run text detection
        debug!("Running text detection...");
        let detection_results = self
            .inference
            .run_detection(image_data, options.detection_threshold)
            .await?;
        debug!("Detected {} regions", detection_results.len());
        if detection_results.is_empty() {
            // No text found is a successful (empty) result, not an error.
            warn!("No text regions detected");
            return Ok(OcrResult {
                text: String::new(),
                confidence: 0.0,
                regions: vec![],
                has_math: false,
                processing_time_ms: start.elapsed().as_millis() as u64,
            });
        }
        // Step 2: Run recognition on each detected region
        debug!("Running text recognition...");
        let mut text_regions = Vec::new();
        let mut has_math = false;
        for detection in detection_results {
            // Determine region type
            // (math only when enabled AND the detector flagged it as likely)
            let region_type = if options.enable_math && detection.is_math_likely {
                has_math = true;
                RegionType::Math
            } else {
                RegionType::Text
            };
            // Run appropriate recognition
            let recognition = if region_type == RegionType::Math {
                self.inference
                    .run_math_recognition(&detection.region_image, options)
                    .await?
            } else {
                self.inference
                    .run_recognition(&detection.region_image, options)
                    .await?
            };
            // Decode the recognition output
            let decoded_text = self.decode_output(&recognition, options)?;
            // Calculate confidence
            let confidence = aggregate_confidence(&recognition.character_confidences);
            // Filter by confidence threshold
            if confidence < options.recognition_threshold {
                debug!(
                    "Skipping region with low confidence: {:.2} < {:.2}",
                    confidence, options.recognition_threshold
                );
                continue;
            }
            // Build character list
            // NOTE(review): `zip` silently truncates to the shorter of the
            // decoded text and the confidence vector — confirm the decoder
            // guarantees matching lengths.
            let characters = decoded_text
                .chars()
                .zip(recognition.character_confidences.iter())
                .map(|(ch, &conf)| Character {
                    char: ch,
                    confidence: conf,
                    bbox: None, // Could be populated if available from model
                })
                .collect();
            text_regions.push(TextRegion {
                bbox: detection.bbox,
                text: decoded_text,
                confidence,
                region_type,
                characters,
            });
        }
        // Step 3: Combine results
        // (regions joined with single spaces, in detection order)
        let combined_text = text_regions
            .iter()
            .map(|r| r.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        // Overall confidence is the arithmetic mean of surviving regions.
        let overall_confidence = if text_regions.is_empty() {
            0.0
        } else {
            text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
        };
        let processing_time_ms = start.elapsed().as_millis() as u64;
        debug!(
            "OCR completed in {}ms, recognized {} regions",
            processing_time_ms,
            text_regions.len()
        );
        Ok(OcrResult {
            text: combined_text,
            confidence: overall_confidence,
            regions: text_regions,
            has_math,
            processing_time_ms,
        })
    }
    /// Batch process multiple images
    ///
    /// Images are processed SEQUENTIALLY, each by blocking on the async
    /// recognition path; the whole batch fails on the first error.
    pub async fn recognize_batch(
        &self,
        images: &[&[u8]],
        options: &OcrOptions,
    ) -> Result<Vec<OcrResult>> {
        info!("Processing batch of {} images", images.len());
        let start = Instant::now();
        // Process images one at a time. (The original comment claimed
        // rayon parallelism, but no parallel iterator is used here.)
        let results: Result<Vec<OcrResult>> = images
            .iter()
            .map(|image_data| {
                // Note: In a real async implementation, we'd use tokio::spawn
                // For now, we'll use blocking since we're in a sync context
                futures::executor::block_on(self.recognize_with_options(image_data, options))
            })
            .collect();
        info!("Batch processing completed in {:?}", start.elapsed());
        results
    }
    /// Decode recognition output using the selected decoder
    ///
    /// A fresh decoder is constructed per call from `options.decoder_type`;
    /// all decoders share the engine's vocabulary via `Arc`.
    fn decode_output(
        &self,
        recognition: &RecognitionResult,
        options: &OcrOptions,
    ) -> Result<String> {
        debug!("Decoding output with {:?} decoder", options.decoder_type);
        let decoded = match options.decoder_type {
            DecoderType::BeamSearch => {
                let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
                decoder.decode(&recognition.logits)?
            }
            DecoderType::Greedy => {
                let decoder = GreedyDecoder::new(self.vocabulary.clone());
                decoder.decode(&recognition.logits)?
            }
            DecoderType::CTC => {
                let decoder = CTCDecoder::new(self.vocabulary.clone());
                decoder.decode(&recognition.logits)?
            }
        };
        Ok(decoded)
    }
    /// Get the current model registry
    pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
        Arc::clone(&self.registry)
    }
    /// Get the default options
    pub fn default_options(&self) -> &OcrOptions {
        &self.default_options
    }
    /// Check if engine is warmed up
    pub fn is_warmed_up(&self) -> bool {
        *self.warmed_up.read()
    }
}
impl OcrProcessor for OcrEngine {
    /// Synchronous wrapper around [`OcrEngine::recognize_with_options`].
    /// NOTE(review): `block_on` will block the calling thread — do not
    /// invoke from inside an async executor.
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
        // Blocking wrapper for async method
        futures::executor::block_on(self.recognize_with_options(image_data, options))
    }
    /// Synchronous wrapper around [`OcrEngine::recognize_batch`].
    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
        // Blocking wrapper for async method
        futures::executor::block_on(self.recognize_batch(images, options))
    }
}
#[cfg(test)]
mod tests {
    //! Lightweight engine tests: these avoid constructing a real engine
    //! (which would require models) and only exercise plain data handling.
    use super::*;
    // Struct-update syntax selects the decoder while keeping other defaults.
    #[test]
    fn test_decoder_selection() {
        let options = OcrOptions {
            decoder_type: DecoderType::BeamSearch,
            ..Default::default()
        };
        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
    }
    // The warmed-up flag is a plain Arc<RwLock<bool>>; verify round-trip.
    #[test]
    fn test_warmup_flag() {
        let flag = Arc::new(RwLock::new(false));
        assert!(!*flag.read());
        *flag.write() = true;
        assert!(*flag.read());
    }
}

View File

@@ -0,0 +1,790 @@
//! ONNX Inference Module
//!
//! This module handles ONNX inference operations for text detection,
//! character recognition, and mathematical expression recognition.
//!
//! # Model Requirements
//!
//! This module requires ONNX models to be available in the configured model directory.
//! Without models, all inference operations will return errors.
//!
//! To use this module:
//! 1. Download compatible ONNX models (PaddleOCR, TrOCR, or similar)
//! 2. Place them in the models directory
//! 3. Enable the `ocr` feature flag
use super::{models::ModelHandle, OcrError, OcrOptions, Result};
use image::{DynamicImage, GenericImageView};
use std::sync::Arc;
use tracing::{debug, info, warn};
#[cfg(feature = "ocr")]
use ndarray::Array4;
#[cfg(feature = "ocr")]
use ort::value::Tensor;
/// Result from text detection
///
/// One detected region, with its location in the ORIGINAL image and a
/// cropped copy of the pixels for the recognition stage.
#[derive(Debug, Clone)]
pub struct DetectionResult {
    /// Bounding box [x, y, width, height] in pixels of the original image
    pub bbox: [f32; 4],
    /// Detection confidence reported by the model
    pub confidence: f32,
    /// Cropped image region (PNG-encoded)
    pub region_image: Vec<u8>,
    /// Whether this region likely contains math (aspect-ratio heuristic)
    pub is_math_likely: bool,
}
/// Result from text/math recognition
#[derive(Debug, Clone)]
pub struct RecognitionResult {
    /// Logits output from the model: one `Vec<f32>` of vocabulary scores
    /// per sequence step, i.e. [sequence_length][vocab_size]
    pub logits: Vec<Vec<f32>>,
    /// Character-level confidence scores (max softmax probability per step)
    pub character_confidences: Vec<f32>,
    /// Raw flattened output tensor (for debugging)
    pub raw_output: Option<Vec<f32>>,
}
/// Inference engine for running ONNX models
///
/// IMPORTANT: This engine requires ONNX models to be loaded.
/// All methods will return errors if models are not properly initialized;
/// check [`InferenceEngine::is_ready`] before use.
pub struct InferenceEngine {
    /// Detection model
    detection_model: Arc<ModelHandle>,
    /// Recognition model
    recognition_model: Arc<ModelHandle>,
    /// Math recognition model (optional; text recognition is the fallback)
    math_model: Option<Arc<ModelHandle>>,
    /// Whether to use GPU acceleration
    use_gpu: bool,
    /// Whether models are actually loaded (vs placeholder handles);
    /// snapshotted at construction from `ModelHandle::is_loaded`
    models_loaded: bool,
}
impl InferenceEngine {
/// Create a new inference engine
pub fn new(
detection_model: Arc<ModelHandle>,
recognition_model: Arc<ModelHandle>,
math_model: Option<Arc<ModelHandle>>,
use_gpu: bool,
) -> Result<Self> {
// Check if models are actually loaded with ONNX sessions
let detection_loaded = detection_model.is_loaded();
let recognition_loaded = recognition_model.is_loaded();
let models_loaded = detection_loaded && recognition_loaded;
if !models_loaded {
warn!(
"ONNX models not fully loaded. Detection: {}, Recognition: {}",
detection_loaded, recognition_loaded
);
warn!("OCR inference will fail until models are properly configured.");
} else {
info!(
"Inference engine initialized with loaded models (GPU: {})",
if use_gpu { "enabled" } else { "disabled" }
);
}
Ok(Self {
detection_model,
recognition_model,
math_model,
use_gpu,
models_loaded,
})
}
    /// Check if the inference engine is ready for use
    ///
    /// `true` only when both detection and recognition sessions were loaded
    /// at construction time.
    pub fn is_ready(&self) -> bool {
        self.models_loaded
    }
    /// Run text detection on an image
    ///
    /// Decodes `image_data`, resizes it to the detection model's input shape,
    /// and returns one [`DetectionResult`] per region whose confidence is at
    /// least `threshold`.
    ///
    /// # Errors
    /// - [`OcrError::ModelLoading`] when models were never loaded
    /// - [`OcrError::Inference`] when built without the `ocr` feature
    pub async fn run_detection(
        &self,
        image_data: &[u8],
        threshold: f32,
    ) -> Result<Vec<DetectionResult>> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Please download and configure OCR models before use. \
                 See examples/scipix/docs/MODEL_SETUP.md for instructions."
                    .to_string(),
            ));
        }
        debug!("Running text detection (threshold: {})", threshold);
        let input_tensor = self.preprocess_image_for_detection(image_data)?;
        #[cfg(feature = "ocr")]
        {
            let detections = self
                .run_onnx_detection(&input_tensor, threshold, image_data)
                .await?;
            debug!("Detected {} regions", detections.len());
            return Ok(detections);
        }
        #[cfg(not(feature = "ocr"))]
        {
            Err(OcrError::Inference(
                "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                    .to_string(),
            ))
        }
    }
    /// Run text recognition on a region image
    ///
    /// `region_image` is an encoded image (typically a crop produced by
    /// [`InferenceEngine::run_detection`]); the result carries raw logits
    /// plus per-step confidences for the decoder stage.
    ///
    /// # Errors
    /// - [`OcrError::ModelLoading`] when models were never loaded
    /// - [`OcrError::Inference`] when built without the `ocr` feature
    pub async fn run_recognition(
        &self,
        region_image: &[u8],
        options: &OcrOptions,
    ) -> Result<RecognitionResult> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Please download and configure OCR models before use."
                    .to_string(),
            ));
        }
        debug!("Running text recognition");
        let input_tensor = self.preprocess_image_for_recognition(region_image)?;
        #[cfg(feature = "ocr")]
        {
            let result = self.run_onnx_recognition(&input_tensor, options).await?;
            return Ok(result);
        }
        #[cfg(not(feature = "ocr"))]
        {
            Err(OcrError::Inference(
                "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                    .to_string(),
            ))
        }
    }
/// Run math recognition on a region image
pub async fn run_math_recognition(
&self,
region_image: &[u8],
options: &OcrOptions,
) -> Result<RecognitionResult> {
if !self.models_loaded {
return Err(OcrError::ModelLoading(
"ONNX models not loaded. Please download and configure OCR models before use."
.to_string(),
));
}
debug!("Running math recognition");
if self.math_model.is_none() || !self.math_model.as_ref().unwrap().is_loaded() {
warn!("Math model not loaded, falling back to text recognition");
return self.run_recognition(region_image, options).await;
}
let input_tensor = self.preprocess_image_for_math(region_image)?;
#[cfg(feature = "ocr")]
{
let result = self
.run_onnx_math_recognition(&input_tensor, options)
.await?;
return Ok(result);
}
#[cfg(not(feature = "ocr"))]
{
Err(OcrError::Inference(
"OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
.to_string(),
))
}
}
/// Preprocess image for detection model
fn preprocess_image_for_detection(&self, image_data: &[u8]) -> Result<Vec<f32>> {
let img = image::load_from_memory(image_data)
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
let input_shape = self.detection_model.input_shape();
let (_, _, height, width) = (
input_shape[0],
input_shape[1],
input_shape[2],
input_shape[3],
);
let resized = img.resize_exact(
width as u32,
height as u32,
image::imageops::FilterType::Lanczos3,
);
let rgb = resized.to_rgb8();
let mut tensor = Vec::with_capacity(3 * height * width);
// Convert to NCHW format with normalization
for c in 0..3 {
for y in 0..height {
for x in 0..width {
let pixel = rgb.get_pixel(x as u32, y as u32);
tensor.push(pixel[c] as f32 / 255.0);
}
}
}
Ok(tensor)
}
    /// Preprocess image for recognition model
    ///
    /// Resizes to the recognition model's input size and normalizes pixels
    /// to `[-1, 1]` (`x / 127.5 - 1`) in NCHW order. Supports both
    /// single-channel (grayscale) and 3-channel (RGB) model inputs.
    // NOTE(review): indexes input_shape[0..4] directly — assumes the metadata
    // always describes a 4-D NCHW shape; a shorter shape would panic.
    fn preprocess_image_for_recognition(&self, image_data: &[u8]) -> Result<Vec<f32>> {
        let img = image::load_from_memory(image_data)
            .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
        let input_shape = self.recognition_model.input_shape();
        let (_, channels, height, width) = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        );
        let resized = img.resize_exact(
            width as u32,
            height as u32,
            image::imageops::FilterType::Lanczos3,
        );
        let mut tensor = Vec::with_capacity(channels * height * width);
        if channels == 1 {
            // Grayscale model input: a single luma plane.
            let gray = resized.to_luma8();
            for y in 0..height {
                for x in 0..width {
                    let pixel = gray.get_pixel(x as u32, y as u32);
                    tensor.push((pixel[0] as f32 / 127.5) - 1.0);
                }
            }
        } else {
            // RGB model input: channel-major (planar) layout.
            let rgb = resized.to_rgb8();
            for c in 0..3 {
                for y in 0..height {
                    for x in 0..width {
                        let pixel = rgb.get_pixel(x as u32, y as u32);
                        tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                    }
                }
            }
        }
        Ok(tensor)
    }
    /// Preprocess image for math recognition model
    ///
    /// Same `[-1, 1]` NCHW normalization as the text-recognition path, but
    /// sized against the math model's metadata. Unlike
    /// `preprocess_image_for_recognition`, the RGB branch loops over the
    /// metadata channel count rather than hard-coding 3.
    ///
    /// # Errors
    /// Returns [`OcrError::Inference`] when no math model is configured, or
    /// [`OcrError::ImageProcessing`] when the bytes cannot be decoded.
    // NOTE(review): indexes input_shape[0..4] directly — assumes 4-D NCHW
    // metadata; a shorter shape would panic.
    fn preprocess_image_for_math(&self, image_data: &[u8]) -> Result<Vec<f32>> {
        let math_model = self
            .math_model
            .as_ref()
            .ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
        let img = image::load_from_memory(image_data)
            .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
        let input_shape = math_model.input_shape();
        let (_, channels, height, width) = (
            input_shape[0],
            input_shape[1],
            input_shape[2],
            input_shape[3],
        );
        let resized = img.resize_exact(
            width as u32,
            height as u32,
            image::imageops::FilterType::Lanczos3,
        );
        let mut tensor = Vec::with_capacity(channels * height * width);
        if channels == 1 {
            // Grayscale model input: a single luma plane.
            let gray = resized.to_luma8();
            for y in 0..height {
                for x in 0..width {
                    let pixel = gray.get_pixel(x as u32, y as u32);
                    tensor.push((pixel[0] as f32 / 127.5) - 1.0);
                }
            }
        } else {
            // Multi-channel input: channel-major (planar) layout.
            let rgb = resized.to_rgb8();
            for c in 0..channels {
                for y in 0..height {
                    for x in 0..width {
                        let pixel = rgb.get_pixel(x as u32, y as u32);
                        tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                    }
                }
            }
        }
        Ok(tensor)
    }
/// ONNX detection inference (requires `ocr` feature)
#[cfg(feature = "ocr")]
async fn run_onnx_detection(
&self,
input_tensor: &[f32],
threshold: f32,
original_image: &[u8],
) -> Result<Vec<DetectionResult>> {
let session_arc = self.detection_model.session().ok_or_else(|| {
OcrError::OnnxRuntime("Detection model session not loaded".to_string())
})?;
let mut session = session_arc.lock();
let input_shape = self.detection_model.input_shape();
let shape: Vec<usize> = input_shape.to_vec();
// Create tensor from input data
let input_array = Array4::from_shape_vec(
(shape[0], shape[1], shape[2], shape[3]),
input_tensor.to_vec(),
)
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
// Convert to dynamic-dimension view and create ORT tensor
let input_dyn = input_array.into_dyn();
let input_tensor = Tensor::from_array(input_dyn)
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
// Run inference
let outputs = session
.run(ort::inputs![input_tensor])
.map_err(|e| OcrError::OnnxRuntime(format!("Inference failed: {}", e)))?;
let output_tensor = outputs
.iter()
.next()
.map(|(_, v)| v)
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
let (_, raw_data) = output_tensor
.try_extract_tensor::<f32>()
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
let output_data: Vec<f32> = raw_data.to_vec();
let original_img = image::load_from_memory(original_image)
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
let detections = self.parse_detection_output(&output_data, threshold, &original_img)?;
Ok(detections)
}
/// Parse detection model output
#[cfg(feature = "ocr")]
fn parse_detection_output(
&self,
output: &[f32],
threshold: f32,
original_img: &DynamicImage,
) -> Result<Vec<DetectionResult>> {
let mut results = Vec::new();
let output_shape = self.detection_model.output_shape();
if output_shape.len() >= 2 {
let num_detections = output_shape[1];
let detection_size = if output_shape.len() >= 3 {
output_shape[2]
} else {
85
};
for i in 0..num_detections {
let base_idx = i * detection_size;
if base_idx + 5 > output.len() {
break;
}
let confidence = output[base_idx + 4];
if confidence < threshold {
continue;
}
let cx = output[base_idx];
let cy = output[base_idx + 1];
let w = output[base_idx + 2];
let h = output[base_idx + 3];
let img_width = original_img.width() as f32;
let img_height = original_img.height() as f32;
let x = ((cx - w / 2.0) * img_width).max(0.0);
let y = ((cy - h / 2.0) * img_height).max(0.0);
let width = (w * img_width).min(img_width - x);
let height = (h * img_height).min(img_height - y);
if width <= 0.0 || height <= 0.0 {
continue;
}
let cropped =
original_img.crop_imm(x as u32, y as u32, width as u32, height as u32);
let mut region_bytes = Vec::new();
cropped
.write_to(
&mut std::io::Cursor::new(&mut region_bytes),
image::ImageFormat::Png,
)
.map_err(|e| {
OcrError::ImageProcessing(format!("Failed to encode region: {}", e))
})?;
let aspect_ratio = width / height;
let is_math_likely = aspect_ratio > 2.0 || aspect_ratio < 0.5;
results.push(DetectionResult {
bbox: [x, y, width, height],
confidence,
region_image: region_bytes,
is_math_likely,
});
}
}
Ok(results)
}
/// ONNX recognition inference (requires `ocr` feature)
#[cfg(feature = "ocr")]
async fn run_onnx_recognition(
&self,
input_tensor: &[f32],
_options: &OcrOptions,
) -> Result<RecognitionResult> {
let session_arc = self.recognition_model.session().ok_or_else(|| {
OcrError::OnnxRuntime("Recognition model session not loaded".to_string())
})?;
let mut session = session_arc.lock();
let input_shape = self.recognition_model.input_shape();
let shape: Vec<usize> = input_shape.to_vec();
let input_array = Array4::from_shape_vec(
(shape[0], shape[1], shape[2], shape[3]),
input_tensor.to_vec(),
)
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
let input_dyn = input_array.into_dyn();
let input_ort = Tensor::from_array(input_dyn)
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
let outputs = session
.run(ort::inputs![input_ort])
.map_err(|e| OcrError::OnnxRuntime(format!("Recognition inference failed: {}", e)))?;
let output_tensor = outputs
.iter()
.next()
.map(|(_, v)| v)
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
let (_, raw_data) = output_tensor
.try_extract_tensor::<f32>()
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
let output_data: Vec<f32> = raw_data.to_vec();
let output_shape = self.recognition_model.output_shape();
let seq_len = output_shape.get(1).copied().unwrap_or(26);
let vocab_size = output_shape.get(2).copied().unwrap_or(37);
let mut logits = Vec::new();
let mut character_confidences = Vec::new();
for i in 0..seq_len {
let start_idx = i * vocab_size;
let end_idx = start_idx + vocab_size;
if end_idx <= output_data.len() {
let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
let max_logit = step_logits
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
let softmax: Vec<f32> = step_logits
.iter()
.map(|&x| (x - max_logit).exp() / exp_sum)
.collect();
let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
character_confidences.push(max_confidence);
logits.push(step_logits);
}
}
Ok(RecognitionResult {
logits,
character_confidences,
raw_output: Some(output_data),
})
}
/// ONNX math recognition inference (requires `ocr` feature)
#[cfg(feature = "ocr")]
async fn run_onnx_math_recognition(
&self,
input_tensor: &[f32],
_options: &OcrOptions,
) -> Result<RecognitionResult> {
let math_model = self
.math_model
.as_ref()
.ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;
let session_arc = math_model
.session()
.ok_or_else(|| OcrError::OnnxRuntime("Math model session not loaded".to_string()))?;
let mut session = session_arc.lock();
let input_shape = math_model.input_shape();
let shape: Vec<usize> = input_shape.to_vec();
let input_array = Array4::from_shape_vec(
(shape[0], shape[1], shape[2], shape[3]),
input_tensor.to_vec(),
)
.map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;
let input_dyn = input_array.into_dyn();
let input_ort = Tensor::from_array(input_dyn)
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;
let outputs = session.run(ort::inputs![input_ort]).map_err(|e| {
OcrError::OnnxRuntime(format!("Math recognition inference failed: {}", e))
})?;
let output_tensor = outputs
.iter()
.next()
.map(|(_, v)| v)
.ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;
let (_, raw_data) = output_tensor
.try_extract_tensor::<f32>()
.map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
let output_data: Vec<f32> = raw_data.to_vec();
let output_shape = math_model.output_shape();
let seq_len = output_shape.get(1).copied().unwrap_or(50);
let vocab_size = output_shape.get(2).copied().unwrap_or(512);
let mut logits = Vec::new();
let mut character_confidences = Vec::new();
for i in 0..seq_len {
let start_idx = i * vocab_size;
let end_idx = start_idx + vocab_size;
if end_idx <= output_data.len() {
let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();
let max_logit = step_logits
.iter()
.cloned()
.fold(f32::NEG_INFINITY, f32::max);
let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();
let softmax: Vec<f32> = step_logits
.iter()
.map(|&x| (x - max_logit).exp() / exp_sum)
.collect();
let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
character_confidences.push(max_confidence);
logits.push(step_logits);
}
}
Ok(RecognitionResult {
logits,
character_confidences,
raw_output: Some(output_data),
})
}
    /// Get detection model (borrowed; the engine keeps ownership)
    pub fn detection_model(&self) -> &ModelHandle {
        &self.detection_model
    }
    /// Get recognition model (borrowed; the engine keeps ownership)
    pub fn recognition_model(&self) -> &ModelHandle {
        &self.recognition_model
    }
    /// Get math model if available (`None` when not configured)
    pub fn math_model(&self) -> Option<&ModelHandle> {
        self.math_model.as_ref().map(|m| m.as_ref())
    }
    /// Check if GPU acceleration is enabled
    pub fn is_gpu_enabled(&self) -> bool {
        self.use_gpu
    }
}
/// Batch inference optimization
impl InferenceEngine {
    /// Run batch detection on multiple images
    ///
    /// Images are processed sequentially; the first failure aborts the whole
    /// batch and is returned as the error.
    ///
    /// # Errors
    /// [`OcrError::ModelLoading`] when models are not loaded, plus any error
    /// from the per-image [`InferenceEngine::run_detection`] call.
    pub async fn run_batch_detection(
        &self,
        images: &[&[u8]],
        threshold: f32,
    ) -> Result<Vec<Vec<DetectionResult>>> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Cannot run batch detection.".to_string(),
            ));
        }
        debug!("Running batch detection on {} images", images.len());
        // Preallocate: one result slot per input image.
        let mut results = Vec::with_capacity(images.len());
        for image in images {
            let detections = self.run_detection(image, threshold).await?;
            results.push(detections);
        }
        Ok(results)
    }
    /// Run batch recognition on multiple regions
    ///
    /// Regions are processed sequentially; the first failure aborts the
    /// whole batch and is returned as the error.
    ///
    /// # Errors
    /// [`OcrError::ModelLoading`] when models are not loaded, plus any error
    /// from the per-region [`InferenceEngine::run_recognition`] call.
    pub async fn run_batch_recognition(
        &self,
        regions: &[&[u8]],
        options: &OcrOptions,
    ) -> Result<Vec<RecognitionResult>> {
        if !self.models_loaded {
            return Err(OcrError::ModelLoading(
                "ONNX models not loaded. Cannot run batch recognition.".to_string(),
            ));
        }
        debug!("Running batch recognition on {} regions", regions.len());
        // Preallocate: one result slot per input region.
        let mut results = Vec::with_capacity(regions.len());
        for region in regions {
            let result = self.run_recognition(region, options).await?;
            results.push(result);
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ocr::models::{ModelMetadata, ModelType};
    use std::path::PathBuf;
    /// Build a handle with plausible metadata; for nonexistent paths the
    /// ONNX session stays unloaded, so engines built from it are not ready.
    fn create_test_model(model_type: ModelType, path: PathBuf) -> Arc<ModelHandle> {
        let metadata = ModelMetadata {
            name: format!("{:?} Model", model_type),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };
        Arc::new(ModelHandle::new(model_type, path, metadata).unwrap())
    }
    // Construction must succeed even with missing model files…
    #[test]
    fn test_inference_engine_creation_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }
    // …but actually running detection must fail with ModelLoading.
    #[tokio::test]
    async fn test_detection_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        let png_data = create_test_png();
        let result = engine.run_detection(&png_data, 0.5).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }
    // Same guard applies to the recognition path.
    #[tokio::test]
    async fn test_recognition_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        let png_data = create_test_png();
        let options = OcrOptions::default();
        let result = engine.run_recognition(&png_data, &options).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }
    #[test]
    fn test_is_ready_reflects_model_state() {
        let detection = create_test_model(ModelType::Detection, PathBuf::from("/fake/path"));
        let recognition = create_test_model(ModelType::Recognition, PathBuf::from("/fake/path"));
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }
    /// Tiny 10x10 white PNG used as a valid-but-trivial input image.
    fn create_test_png() -> Vec<u8> {
        use image::{ImageBuffer, RgbImage};
        let img: RgbImage = ImageBuffer::from_fn(10, 10, |_, _| image::Rgb([255, 255, 255]));
        let mut bytes: Vec<u8> = Vec::new();
        img.write_to(
            &mut std::io::Cursor::new(&mut bytes),
            image::ImageFormat::Png,
        )
        .unwrap();
        bytes
    }
}

View File

@@ -0,0 +1,235 @@
//! OCR Engine Module
//!
//! This module provides optical character recognition capabilities for the ruvector-scipix system.
//! It supports text detection, character recognition, and mathematical expression recognition using
//! ONNX models for high-performance inference.
//!
//! # Architecture
//!
//! The OCR module is organized into several submodules:
//! - `engine`: Main OcrEngine for orchestrating OCR operations
//! - `models`: Model management, loading, and caching
//! - `inference`: ONNX inference operations for detection and recognition
//! - `decoder`: Output decoding strategies (beam search, greedy, CTC)
//! - `confidence`: Confidence scoring and calibration
//!
//! # Example
//!
//! ```no_run
//! use ruvector_scipix::ocr::{OcrEngine, OcrOptions};
//!
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
//! // Initialize the OCR engine
//! let engine = OcrEngine::new().await?;
//!
//! // Load an image
//! let image_data = std::fs::read("math_formula.png")?;
//!
//! // Perform OCR
//! let result = engine.recognize(&image_data).await?;
//!
//! println!("Recognized text: {}", result.text);
//! println!("Confidence: {:.2}%", result.confidence * 100.0);
//! # Ok(())
//! # }
//! ```
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
// Submodules
mod confidence;
mod decoder;
mod engine;
mod inference;
mod models;
// Public exports
pub use confidence::{aggregate_confidence, calculate_confidence, ConfidenceCalibrator};
pub use decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary};
pub use engine::{OcrEngine, OcrProcessor};
pub use inference::{DetectionResult, InferenceEngine, RecognitionResult};
pub use models::{ModelHandle, ModelRegistry};
/// OCR processing options
///
/// Every field has a sensible default via [`Default`]; use struct-update
/// syntax to override a subset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
    /// Detection threshold for text regions (0.0-1.0); default 0.5
    pub detection_threshold: f32,
    /// Recognition confidence threshold (0.0-1.0); default 0.6
    pub recognition_threshold: f32,
    /// Enable mathematical expression recognition; default true
    pub enable_math: bool,
    /// Decoder type to use; default [`DecoderType::BeamSearch`]
    pub decoder_type: DecoderType,
    /// Beam width for beam search decoder; default 5
    pub beam_width: usize,
    /// Maximum batch size for inference; default 1
    pub batch_size: usize,
    /// Enable GPU acceleration if available; default false
    pub use_gpu: bool,
    /// Language hints for recognition (e.g. "en"); default `["en"]`
    pub languages: Vec<String>,
}
impl Default for OcrOptions {
    /// Balanced defaults: beam search (width 5), math enabled, CPU-only,
    /// English language hint.
    fn default() -> Self {
        Self {
            detection_threshold: 0.5,
            recognition_threshold: 0.6,
            enable_math: true,
            decoder_type: DecoderType::BeamSearch,
            beam_width: 5,
            batch_size: 1,
            use_gpu: false,
            languages: vec!["en".to_string()],
        }
    }
}
/// Decoder type selection
///
/// Chooses how recognition logits are turned into text; see the `decoder`
/// submodule for the implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DecoderType {
    /// Beam search decoder (higher quality, slower)
    BeamSearch,
    /// Greedy decoder (faster, lower quality)
    Greedy,
    /// CTC decoder for sequence-to-sequence models
    CTC,
}
/// OCR result containing recognized text and metadata
///
/// Top-level output of the engine: aggregated text plus the per-region
/// detail it was assembled from.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Recognized text
    pub text: String,
    /// Overall confidence score (0.0-1.0)
    pub confidence: f32,
    /// Detected text regions with their bounding boxes
    pub regions: Vec<TextRegion>,
    /// Whether mathematical expressions were detected
    pub has_math: bool,
    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}
/// A detected text region with position and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
    /// Bounding box coordinates [x, y, width, height]
    pub bbox: [f32; 4],
    /// Recognized text in this region
    pub text: String,
    /// Confidence score for this region (0.0-1.0)
    pub confidence: f32,
    /// Region type (text, math, etc.)
    pub region_type: RegionType,
    /// Character-level details if available (may be empty)
    pub characters: Vec<Character>,
}
/// Type of text region
///
/// Classification of a detected region, used to route it to the right
/// recognizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
    /// Regular text
    Text,
    /// Mathematical expression
    Math,
    /// Diagram or figure
    Diagram,
    /// Table
    Table,
    /// Unknown type
    Unknown,
}
/// Individual character with position and confidence
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Character {
    /// The character
    pub char: char,
    /// Confidence score (0.0-1.0)
    pub confidence: f32,
    /// Bounding box if available ([x, y, width, height])
    pub bbox: Option<[f32; 4]>,
}
/// Error types for OCR operations
///
/// `Display` messages come from the `thiserror` attributes; each variant
/// carries a human-readable detail string (or the source `io::Error`).
#[derive(Debug, thiserror::Error)]
pub enum OcrError {
    /// Model file missing, unreadable, or session construction failed
    #[error("Model loading error: {0}")]
    ModelLoading(String),
    /// Model ran but inference could not complete (bad tensor, no model)
    #[error("Inference error: {0}")]
    Inference(String),
    /// Image bytes could not be decoded or re-encoded
    #[error("Image processing error: {0}")]
    ImageProcessing(String),
    /// Logits could not be decoded into text
    #[error("Decoding error: {0}")]
    Decoding(String),
    /// Options or model metadata are inconsistent
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    /// Underlying I/O failure (auto-converted via `From`)
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// Error reported by the ONNX Runtime bindings
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),
}
/// Convenience alias used throughout the OCR module.
pub type Result<T> = std::result::Result<T, OcrError>;
#[cfg(test)]
mod tests {
    use super::*;
    // Pin the documented default values so changes are deliberate.
    #[test]
    fn test_ocr_options_default() {
        let options = OcrOptions::default();
        assert_eq!(options.detection_threshold, 0.5);
        assert_eq!(options.recognition_threshold, 0.6);
        assert!(options.enable_math);
        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
        assert_eq!(options.beam_width, 5);
    }
    // Plain construction smoke test for the region type.
    #[test]
    fn test_text_region_creation() {
        let region = TextRegion {
            bbox: [10.0, 20.0, 100.0, 30.0],
            text: "Test".to_string(),
            confidence: 0.95,
            region_type: RegionType::Text,
            characters: vec![],
        };
        assert_eq!(region.bbox[0], 10.0);
        assert_eq!(region.text, "Test");
        assert_eq!(region.region_type, RegionType::Text);
    }
    // Eq/PartialEq derive sanity check across variants.
    #[test]
    fn test_decoder_type_equality() {
        assert_eq!(DecoderType::BeamSearch, DecoderType::BeamSearch);
        assert_ne!(DecoderType::BeamSearch, DecoderType::Greedy);
        assert_ne!(DecoderType::Greedy, DecoderType::CTC);
    }
}

View File

@@ -0,0 +1,373 @@
//! Model Management Module
//!
//! This module handles loading, caching, and managing ONNX models for OCR.
//! It supports lazy loading, model downloading with progress tracking,
//! and checksum verification.
use super::{OcrError, Result};
use dashmap::DashMap;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tracing::{debug, info, warn};
#[cfg(feature = "ocr")]
use ort::session::Session;
#[cfg(feature = "ocr")]
use parking_lot::Mutex;
/// Model types supported by the OCR engine
///
/// Also serves as the cache key in [`ModelRegistry`] (hence `Hash`/`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Text detection model (finds text regions in images)
    Detection,
    /// Text recognition model (recognizes characters in regions)
    Recognition,
    /// Math expression recognition model
    Math,
}
/// Handle to a loaded ONNX model
///
/// A handle always exists once created, but its `session` may be absent
/// when the model file was missing or failed to load — check
/// [`ModelHandle::is_loaded`].
#[derive(Clone)]
pub struct ModelHandle {
    /// Model type
    model_type: ModelType,
    /// Path to the model file
    path: PathBuf,
    /// Model metadata
    metadata: ModelMetadata,
    /// ONNX Runtime session (when ocr feature is enabled)
    /// Wrapped in Mutex for mutable access required by ort 2.0 Session::run
    #[cfg(feature = "ocr")]
    session: Option<Arc<Mutex<Session>>>,
    /// Mock session for when ocr feature is disabled (always `None`)
    #[cfg(not(feature = "ocr"))]
    #[allow(dead_code)]
    session: Option<()>,
}
impl ModelHandle {
    /// Create a new model handle
    ///
    /// Never fails outright in practice: a missing file or a session-build
    /// error is logged as a warning and leaves `session` as `None`, so
    /// [`ModelHandle::is_loaded`] reports the real state. Without the `ocr`
    /// feature no session is ever created.
    pub fn new(model_type: ModelType, path: PathBuf, metadata: ModelMetadata) -> Result<Self> {
        debug!("Creating model handle for {:?} at {:?}", model_type, path);
        #[cfg(feature = "ocr")]
        let session = if path.exists() {
            match Session::builder() {
                Ok(builder) => match builder.commit_from_file(&path) {
                    Ok(session) => {
                        info!("Successfully loaded ONNX model: {:?}", path);
                        Some(Arc::new(Mutex::new(session)))
                    }
                    Err(e) => {
                        // File exists but is not a usable ONNX model.
                        warn!("Failed to load ONNX model {:?}: {}", path, e);
                        None
                    }
                },
                Err(e) => {
                    warn!("Failed to create ONNX session builder: {}", e);
                    None
                }
            }
        } else {
            debug!("Model file not found: {:?}", path);
            None
        };
        #[cfg(not(feature = "ocr"))]
        let session: Option<()> = None;
        Ok(Self {
            model_type,
            path,
            metadata,
            session,
        })
    }
    /// Check if the model session is loaded
    ///
    /// Always `false` when the `ocr` feature is disabled.
    pub fn is_loaded(&self) -> bool {
        self.session.is_some()
    }
    /// Get the ONNX session (only available with ocr feature)
    ///
    /// Behind a `Mutex` because ort 2.0 `Session::run` takes `&mut self`.
    #[cfg(feature = "ocr")]
    pub fn session(&self) -> Option<&Arc<Mutex<Session>>> {
        self.session.as_ref()
    }
    /// Get the model type
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }
    /// Get the model path
    pub fn path(&self) -> &Path {
        &self.path
    }
    /// Get model metadata
    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }
    /// Get input shape for the model (NCHW, from metadata)
    pub fn input_shape(&self) -> &[usize] {
        &self.metadata.input_shape
    }
    /// Get output shape for the model (from metadata)
    pub fn output_shape(&self) -> &[usize] {
        &self.metadata.output_shape
    }
}
/// Model metadata
///
/// Describes a model's expected tensor shapes and file properties; supplied
/// by [`ModelRegistry`] defaults rather than read from the ONNX file itself.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model version
    pub version: String,
    /// Input tensor shape (NCHW)
    pub input_shape: Vec<usize>,
    /// Output tensor shape
    pub output_shape: Vec<usize>,
    /// Expected input data type
    pub input_dtype: String,
    /// File size in bytes (approximate)
    pub file_size: u64,
    /// SHA256 checksum (optional; `None` skips verification)
    pub checksum: Option<String>,
}
/// Model registry for loading and caching models
///
/// Concurrent cache keyed by [`ModelType`]; models are looked up on disk
/// under `model_dir`.
pub struct ModelRegistry {
    /// Cache of loaded models
    cache: DashMap<ModelType, Arc<ModelHandle>>,
    /// Base directory for models
    model_dir: PathBuf,
    /// Whether to enable lazy loading (missing files warn instead of erroring)
    lazy_loading: bool,
}
impl ModelRegistry {
    /// Create a new model registry
    ///
    /// Uses the default `./models` directory relative to the working dir.
    pub fn new() -> Self {
        Self::with_model_dir(PathBuf::from("./models"))
    }
    /// Create a new model registry with custom model directory
    ///
    /// Lazy loading starts enabled: missing model files produce warnings
    /// rather than errors from `load_model`.
    pub fn with_model_dir(model_dir: PathBuf) -> Self {
        info!("Initializing model registry at {:?}", model_dir);
        Self {
            cache: DashMap::new(),
            model_dir,
            lazy_loading: true,
        }
    }
    /// Load the detection model (convenience wrapper over [`Self::load_model`])
    pub async fn load_detection_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Detection).await
    }
    /// Load the recognition model (convenience wrapper over [`Self::load_model`])
    pub async fn load_recognition_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Recognition).await
    }
    /// Load the math recognition model (convenience wrapper over [`Self::load_model`])
    pub async fn load_math_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Math).await
    }
/// Load a model by type
pub async fn load_model(&mut self, model_type: ModelType) -> Result<Arc<ModelHandle>> {
// Check cache first
if let Some(handle) = self.cache.get(&model_type) {
debug!("Model {:?} found in cache", model_type);
return Ok(Arc::clone(handle.value()));
}
info!("Loading model {:?}...", model_type);
// Get model path
let model_path = self.get_model_path(model_type);
// Check if model exists
if !model_path.exists() {
if self.lazy_loading {
warn!(
"Model {:?} not found at {:?}. OCR will not work without models.",
model_type, model_path
);
warn!("Download models from: https://github.com/PaddlePaddle/PaddleOCR or configure custom models.");
} else {
return Err(OcrError::ModelLoading(format!(
"Model {:?} not found at {:?}",
model_type, model_path
)));
}
}
// Load model metadata
let metadata = self.get_model_metadata(model_type);
// Verify checksum if provided
if let Some(ref checksum) = metadata.checksum {
if model_path.exists() {
debug!("Verifying model checksum: {}", checksum);
// In production: verify_checksum(&model_path, checksum)?;
}
}
// Create model handle (will load ONNX session if file exists)
let handle = Arc::new(ModelHandle::new(model_type, model_path, metadata)?);
// Cache the handle
self.cache.insert(model_type, Arc::clone(&handle));
if handle.is_loaded() {
info!(
"Model {:?} loaded successfully with ONNX session",
model_type
);
} else {
warn!(
"Model {:?} handle created but ONNX session not loaded",
model_type
);
}
Ok(handle)
}
/// Get the file path for a model type
fn get_model_path(&self, model_type: ModelType) -> PathBuf {
let filename = match model_type {
ModelType::Detection => "text_detection.onnx",
ModelType::Recognition => "text_recognition.onnx",
ModelType::Math => "math_recognition.onnx",
};
self.model_dir.join(filename)
}
    /// Get default metadata for a model type
    ///
    /// Hard-coded defaults — shapes and sizes are NOT read from the ONNX
    /// file; they must match the models actually deployed. No checksum is
    /// configured by default, so verification is skipped.
    fn get_model_metadata(&self, model_type: ModelType) -> ModelMetadata {
        match model_type {
            ModelType::Detection => ModelMetadata {
                name: "Text Detection".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 3, 640, 640], // NCHW format
                output_shape: vec![1, 25200, 85],  // Detections
                input_dtype: "float32".to_string(),
                file_size: 50_000_000, // ~50MB
                checksum: None,
            },
            ModelType::Recognition => ModelMetadata {
                name: "Text Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 32, 128], // NCHW format
                output_shape: vec![1, 26, 37],    // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 20_000_000, // ~20MB
                checksum: None,
            },
            ModelType::Math => ModelMetadata {
                name: "Math Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 64, 256], // NCHW format
                output_shape: vec![1, 50, 512],   // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 80_000_000, // ~80MB
                checksum: None,
            },
        }
    }
/// Clear the model cache, dropping all cached handles.
pub fn clear_cache(&mut self) {
    info!("Clearing model cache");
    self.cache.clear();
}
/// Get a cached model if available (never triggers loading).
pub fn get_cached(&self, model_type: ModelType) -> Option<Arc<ModelHandle>> {
    self.cache.get(&model_type).map(|h| Arc::clone(h.value()))
}
/// Set lazy loading mode: when enabled, missing model files warn instead of
/// erroring in `load_model`.
pub fn set_lazy_loading(&mut self, enabled: bool) {
    self.lazy_loading = enabled;
}
/// Get the model directory
pub fn model_dir(&self) -> &Path {
    &self.model_dir
}
}
impl Default for ModelRegistry {
    /// Default registry: `./models` directory with lazy loading enabled.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_registry_creation() {
        // Defaults: ./models directory, lazy loading on.
        let registry = ModelRegistry::new();
        assert_eq!(registry.model_dir(), Path::new("./models"));
        assert!(registry.lazy_loading);
    }

    #[test]
    fn test_model_path_generation() {
        let registry = ModelRegistry::new();
        let path = registry.get_model_path(ModelType::Detection);
        assert!(path.to_string_lossy().contains("text_detection.onnx"));
    }

    #[test]
    fn test_model_metadata() {
        let registry = ModelRegistry::new();
        let metadata = registry.get_model_metadata(ModelType::Recognition);
        assert_eq!(metadata.name, "Text Recognition");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.input_shape, vec![1, 1, 32, 128]);
    }

    #[tokio::test]
    async fn test_model_caching() {
        // Two loads of the same model must share one cached handle.
        let mut registry = ModelRegistry::new();
        let model1 = registry.load_detection_model().await.unwrap();
        let model2 = registry.load_detection_model().await.unwrap();
        assert!(Arc::ptr_eq(&model1, &model2));
    }

    #[test]
    fn test_clear_cache() {
        let mut registry = ModelRegistry::new();
        registry.clear_cache();
        assert_eq!(registry.cache.len(), 0);
    }

    #[test]
    fn test_model_handle_without_file() {
        // Lazy path: a handle is still created when the file is missing,
        // but it reports not-loaded.
        let path = PathBuf::from("/nonexistent/model.onnx");
        let metadata = ModelMetadata {
            name: "Test".to_string(),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };
        let handle = ModelHandle::new(ModelType::Detection, path, metadata).unwrap();
        assert!(!handle.is_loaded());
    }
}

View File

@@ -0,0 +1,396 @@
//! Dynamic batching for throughput optimization
//!
//! Provides intelligent batching to maximize GPU/CPU utilization while
//! maintaining acceptable latency.
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::{oneshot, Mutex};
use tokio::time::sleep;
/// Item in the batching queue: the payload plus the oneshot channel used to
/// deliver its result back to the waiting caller.
pub struct BatchItem<T, R> {
    pub data: T,
    pub response: oneshot::Sender<BatchResult<R>>,
    pub enqueued_at: Instant, // used for queue-latency stats
}
/// Result of batch processing
pub type BatchResult<T> = std::result::Result<T, BatchError>;
/// Batch processing errors
#[derive(Debug, Clone)]
pub enum BatchError {
    /// No result arrived (e.g. the batcher shut down before processing).
    Timeout,
    /// The processor reported a per-item failure with this message.
    ProcessingFailed(String),
    /// The queue reached `max_queue_size`; the item was rejected.
    QueueFull,
}
impl std::fmt::Display for BatchError {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
BatchError::Timeout => write!(f, "Batch processing timeout"),
BatchError::ProcessingFailed(msg) => write!(f, "Processing failed: {}", msg),
BatchError::QueueFull => write!(f, "Queue is full"),
}
}
}
impl std::error::Error for BatchError {}
/// Dynamic batcher configuration
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum items in a batch
    pub max_batch_size: usize,
    /// Maximum time to wait before processing partial batch
    pub max_wait_ms: u64,
    /// Maximum queue size
    pub max_queue_size: usize,
    /// Minimum batch size to prefer
    /// NOTE(review): effectively informational — the run loop's
    /// "non-empty and past max_wait" check subsumes the preferred-size check.
    pub preferred_batch_size: usize,
}
impl Default for BatchConfig {
    /// Defaults tuned for low-latency batching: 32-item batches, 50 ms max
    /// wait, 1000-item queue.
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            max_queue_size: 1000,
            preferred_batch_size: 16,
        }
    }
}
/// Dynamic batcher for throughput optimization.
///
/// Callers `add` items; a background `run` loop drains the shared queue in
/// batches and fans results back over per-item oneshot channels.
pub struct DynamicBatcher<T, R> {
    config: BatchConfig,
    queue: Arc<Mutex<VecDeque<BatchItem<T, R>>>>,
    /// Synchronous batch function: must yield one result per input, in order.
    processor: Arc<dyn Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync>,
    /// Cooperative shutdown flag checked on each loop iteration.
    shutdown: Arc<Mutex<bool>>,
}
impl<T, R> DynamicBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    /// Create new dynamic batcher.
    ///
    /// `processor` is a synchronous batch function that must return exactly
    /// one result per input item, in input order.
    pub fn new<F>(config: BatchConfig, processor: F) -> Self
    where
        F: Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync + 'static,
    {
        Self {
            config,
            queue: Arc::new(Mutex::new(VecDeque::new())),
            processor: Arc::new(processor),
            shutdown: Arc::new(Mutex::new(false)),
        }
    }

    /// Add item to batch queue and await its result.
    ///
    /// Returns `BatchError::QueueFull` without enqueuing when the queue is
    /// at capacity. A dropped response channel (e.g. shutdown raced the
    /// item) surfaces as `BatchError::Timeout`.
    pub async fn add(&self, item: T) -> BatchResult<R> {
        let (tx, rx) = oneshot::channel();
        let batch_item = BatchItem {
            data: item,
            response: tx,
            enqueued_at: Instant::now(),
        };
        {
            let mut queue = self.queue.lock().await;
            if queue.len() >= self.config.max_queue_size {
                return Err(BatchError::QueueFull);
            }
            queue.push_back(batch_item);
        }
        // Wait for the processing loop to deliver the result.
        rx.await.map_err(|_| BatchError::Timeout)?
    }

    /// Start batch processing loop; runs until `shutdown()` is called.
    pub async fn run(&self) {
        let mut last_process = Instant::now();
        loop {
            // Cooperative shutdown check.
            {
                let shutdown = self.shutdown.lock().await;
                if *shutdown {
                    break;
                }
            }
            // Process when a full batch is ready, or when any queued item
            // has waited at least `max_wait_ms`. (A queue at
            // preferred_batch_size past max_wait is covered by the
            // non-empty check, so no separate clause is needed.)
            let should_process = {
                let queue = self.queue.lock().await;
                queue.len() >= self.config.max_batch_size
                    || (!queue.is_empty()
                        && last_process.elapsed().as_millis() >= self.config.max_wait_ms as u128)
            };
            if should_process {
                self.process_batch().await;
                last_process = Instant::now();
            } else {
                // Sleep briefly to avoid busy waiting.
                sleep(Duration::from_millis(1)).await;
            }
        }
        // BUGFIX: drain ALL remaining items on shutdown. A single
        // process_batch call handles at most max_batch_size items; anything
        // left over would have its sender dropped and its caller would see
        // a spurious Timeout.
        while self.queue_size().await > 0 {
            self.process_batch().await;
        }
    }

    /// Process up to `max_batch_size` queued items in one processor call.
    async fn process_batch(&self) {
        let items = {
            let mut queue = self.queue.lock().await;
            let batch_size = self.config.max_batch_size.min(queue.len());
            if batch_size == 0 {
                return;
            }
            queue.drain(..batch_size).collect::<Vec<_>>()
        };
        if items.is_empty() {
            return;
        }
        // Split payloads from their response channels, preserving order.
        let (data, responses): (Vec<_>, Vec<_>) = items
            .into_iter()
            .map(|item| (item.data, item.response))
            .unzip();
        // Run the user-supplied batch processor.
        let results = (self.processor)(data);
        // Deliver results; a receiver may have gone away, so ignore send errors.
        for (response_tx, result) in responses.into_iter().zip(results.into_iter()) {
            let batch_result = result.map_err(BatchError::ProcessingFailed);
            let _ = response_tx.send(batch_result);
        }
    }

    /// Gracefully shutdown the batcher; the run loop drains the queue first.
    pub async fn shutdown(&self) {
        let mut shutdown = self.shutdown.lock().await;
        *shutdown = true;
    }

    /// Get current queue size.
    pub async fn queue_size(&self) -> usize {
        self.queue.lock().await.len()
    }

    /// Get current queue statistics (size and age of the oldest item).
    pub async fn stats(&self) -> BatchStats {
        let queue = self.queue.lock().await;
        let queue_size = queue.len();
        let max_wait = queue
            .front()
            .map(|item| item.enqueued_at.elapsed())
            .unwrap_or(Duration::from_secs(0));
        BatchStats {
            queue_size,
            max_wait_time: max_wait,
        }
    }
}
/// Batch statistics
#[derive(Debug, Clone)]
pub struct BatchStats {
    /// Items currently queued.
    pub queue_size: usize,
    /// Age of the oldest queued item (zero when the queue is empty).
    pub max_wait_time: Duration,
}
/// Adaptive batcher that adjusts batch size based on latency
pub struct AdaptiveBatcher<T, R> {
    inner: DynamicBatcher<T, R>,
    /// Adapted copy of the configuration.
    /// NOTE(review): `inner` holds its own clone of the config, so updates
    /// here do not currently propagate to the running batcher — confirm intent.
    config: Arc<Mutex<BatchConfig>>,
    /// Rolling window (last 100) of observed end-to-end request latencies.
    latency_history: Arc<Mutex<VecDeque<Duration>>>,
    target_latency: Duration,
}
impl<T, R> AdaptiveBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    /// Create adaptive batcher with target latency.
    pub fn new<F>(initial_config: BatchConfig, target_latency: Duration, processor: F) -> Self
    where
        F: Fn(Vec<T>) -> Vec<Result<R, String>> + Send + Sync + 'static,
    {
        // NOTE(review): `inner` receives its own clone of the config; the
        // shared `config` below is adapted over time but those changes are
        // never seen by `inner` — confirm whether adaptation should propagate.
        let config = Arc::new(Mutex::new(initial_config.clone()));
        let inner = DynamicBatcher::new(initial_config, processor);
        Self {
            inner,
            config,
            latency_history: Arc::new(Mutex::new(VecDeque::with_capacity(100))),
            target_latency,
        }
    }

    /// Add item, record its latency, and adapt the batch size.
    pub async fn add(&self, item: T) -> Result<R, BatchError> {
        let start = Instant::now();
        let result = self.inner.add(item).await;
        let latency = start.elapsed();
        // Record latency in a rolling 100-entry window.
        {
            let mut history = self.latency_history.lock().await;
            history.push_back(latency);
            if history.len() > 100 {
                history.pop_front();
            }
        }
        // Adapt batch size every 10 requests.
        // (Once the window is pinned at 100 entries, len % 10 == 0 holds on
        // every call, so adaptation then runs per-request.)
        {
            let history = self.latency_history.lock().await;
            if history.len() % 10 == 0 && history.len() >= 10 {
                let avg_latency: Duration = history.iter().sum::<Duration>() / history.len() as u32;
                let mut config = self.config.lock().await;
                if avg_latency > self.target_latency {
                    // Reduce batch size to lower latency (10% step, floor 1).
                    config.max_batch_size = (config.max_batch_size * 9 / 10).max(1);
                } else if avg_latency < self.target_latency / 2 {
                    // Increase batch size for better throughput (cap 128).
                    config.max_batch_size = (config.max_batch_size * 11 / 10).min(128);
                }
            }
        }
        result
    }

    /// Run the batcher
    pub async fn run(&self) {
        self.inner.run().await;
    }

    /// Get a snapshot of the current (adapted) configuration.
    pub async fn current_config(&self) -> BatchConfig {
        self.config.lock().await.clone()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_dynamic_batcher() {
        let config = BatchConfig {
            max_batch_size: 4,
            max_wait_ms: 100,
            max_queue_size: 100,
            preferred_batch_size: 2,
        };
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x * 2)).collect()
        }));
        // Start processing loop
        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });
        // Add items
        let mut handles = vec![];
        for i in 0..8 {
            let batcher = batcher.clone();
            handles.push(tokio::spawn(async move { batcher.add(i).await }));
        }
        // Wait for results
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap().unwrap();
            assert_eq!(result, (i as i32) * 2);
        }
        batcher.shutdown().await;
    }

    #[tokio::test]
    async fn test_batch_stats() {
        let config = BatchConfig::default();
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));
        // BUGFIX: `add` returns a lazy future — `let _ = batcher.add(1);`
        // never polls it, so nothing was ever enqueued and the original
        // assertion could not hold. Spawn the adds so they actually run
        // (they park on their response channels since no run loop exists).
        for i in 1..=3 {
            let b = Arc::clone(&batcher);
            tokio::spawn(async move {
                let _ = b.add(i).await;
            });
        }
        // Give the spawned tasks a chance to enqueue.
        sleep(Duration::from_millis(50)).await;
        let stats = batcher.stats().await;
        assert_eq!(stats.queue_size, 3);
    }

    #[tokio::test]
    async fn test_queue_full() {
        let config = BatchConfig {
            max_queue_size: 2,
            ..Default::default()
        };
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));
        // Fill the queue: spawn the adds so the futures are actually polled
        // (no run loop is started, so the items stay queued).
        for i in 0..2 {
            let b = Arc::clone(&batcher);
            tokio::spawn(async move {
                let _ = b.add(i).await;
            });
        }
        sleep(Duration::from_millis(50)).await;
        // This should fail immediately — queue is full.
        let result = batcher.add(3).await;
        assert!(matches!(result, Err(BatchError::QueueFull)));
    }

    #[tokio::test]
    async fn test_adaptive_batcher() {
        let config = BatchConfig {
            max_batch_size: 8,
            max_wait_ms: 50,
            max_queue_size: 100,
            preferred_batch_size: 4,
        };
        let batcher = Arc::new(AdaptiveBatcher::new(
            config,
            Duration::from_millis(100),
            |items: Vec<i32>| items.into_iter().map(|x| Ok(x * 2)).collect(),
        ));
        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });
        // Process some requests
        for i in 0..20 {
            let result = batcher.add(i).await.unwrap();
            assert_eq!(result, i * 2);
        }
        // Configuration should have adapted
        let final_config = batcher.current_config().await;
        assert!(final_config.max_batch_size > 0);
    }
}

View File

@@ -0,0 +1,409 @@
//! Memory optimization utilities
//!
//! Provides object pooling, memory-mapped file loading, and zero-copy operations.
use memmap2::{Mmap, MmapOptions};
use std::collections::VecDeque;
use std::fs::File;
use std::path::Path;
use std::sync::{Arc, Mutex};
use super::memory_opt_enabled;
use crate::error::{Result, ScipixError};
/// Object pool for reusable buffers.
///
/// Buffers are handed out as RAII `PooledBuffer` guards that return to the
/// pool on drop. Pooling is bypassed entirely when memory optimizations are
/// disabled (see `memory_opt_enabled`).
pub struct BufferPool<T> {
    pool: Arc<Mutex<VecDeque<T>>>,
    /// Factory used when the pool is empty.
    factory: Arc<dyn Fn() -> T + Send + Sync>,
    // NOTE(review): max_size is stored but never enforced; the pool can
    // grow past it as guards are dropped back in.
    #[allow(dead_code)]
    max_size: usize,
}
impl<T: Send + 'static> BufferPool<T> {
    /// Create a new buffer pool, pre-allocating `initial_size` buffers
    /// (only when memory optimizations are enabled).
    pub fn new<F>(factory: F, initial_size: usize, max_size: usize) -> Self
    where
        F: Fn() -> T + Send + Sync + 'static,
    {
        let factory = Arc::new(factory);
        let pool = Arc::new(Mutex::new(VecDeque::with_capacity(max_size)));
        // Pre-allocate initial buffers
        if memory_opt_enabled() {
            let mut pool_lock = pool.lock().unwrap();
            for _ in 0..initial_size {
                pool_lock.push_back(factory());
            }
        }
        Self {
            pool,
            factory,
            max_size,
        }
    }
    /// Acquire a buffer from the pool (or build a fresh one if empty).
    ///
    /// NOTE(review): recycled buffers are returned as-is — callers must
    /// clear any previous contents themselves.
    pub fn acquire(&self) -> PooledBuffer<T> {
        let buffer = if memory_opt_enabled() {
            self.pool
                .lock()
                .unwrap()
                .pop_front()
                .unwrap_or_else(|| (self.factory)())
        } else {
            (self.factory)()
        };
        PooledBuffer {
            buffer: Some(buffer),
            pool: self.pool.clone(),
        }
    }
    /// Get current pool size (idle buffers only; outstanding guards excluded).
    pub fn size(&self) -> usize {
        self.pool.lock().unwrap().len()
    }
    /// Clear the pool, dropping all idle buffers.
    pub fn clear(&self) {
        self.pool.lock().unwrap().clear();
    }
}
/// RAII guard for pooled buffers; returns the buffer to its pool on drop.
pub struct PooledBuffer<T> {
    /// Always `Some` until `Drop` takes it.
    buffer: Option<T>,
    pool: Arc<Mutex<VecDeque<T>>>,
}
impl<T> PooledBuffer<T> {
    /// Get mutable reference to buffer.
    pub fn get_mut(&mut self) -> &mut T {
        self.buffer.as_mut().unwrap()
    }
    /// Get immutable reference to buffer.
    pub fn get(&self) -> &T {
        self.buffer.as_ref().unwrap()
    }
}
impl<T> Drop for PooledBuffer<T> {
    fn drop(&mut self) {
        // Return the buffer to the pool (uncleared) when pooling is active;
        // otherwise the buffer simply drops.
        if memory_opt_enabled() {
            if let Some(buffer) = self.buffer.take() {
                let mut pool = self.pool.lock().unwrap();
                pool.push_back(buffer);
            }
        }
    }
}
impl<T> std::ops::Deref for PooledBuffer<T> {
    type Target = T;
    fn deref(&self) -> &Self::Target {
        self.buffer.as_ref().unwrap()
    }
}
impl<T> std::ops::DerefMut for PooledBuffer<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.buffer.as_mut().unwrap()
    }
}
/// Memory-mapped model file.
///
/// Owns a read-only mapping of the file; `as_slice` exposes it zero-copy.
/// The raw-pointer + `unsafe impl Send/Sync` plumbing of the original is
/// unnecessary: borrowing through the owned `Mmap` is safe, and `Mmap` is
/// already `Send + Sync`, so those auto traits are preserved.
pub struct MmapModel {
    mmap: Mmap,
}
impl MmapModel {
    /// Load model from file using memory mapping.
    ///
    /// # Errors
    /// Returns an I/O error if the file cannot be opened or mapped.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path.as_ref()).map_err(ScipixError::Io)?;
        // SAFETY: the mapping is read-only; soundness additionally relies on
        // the underlying file not being truncated while mapped (standard
        // memmap2 caveat).
        let mmap = unsafe { MmapOptions::new().map(&file).map_err(ScipixError::Io)? };
        Ok(Self { mmap })
    }
    /// Get the mapped bytes as a slice (zero-copy).
    pub fn as_slice(&self) -> &[u8] {
        &self.mmap[..]
    }
    /// Get size of the mapped region in bytes.
    pub fn len(&self) -> usize {
        self.mmap.len()
    }
    /// Check if the mapped region is empty.
    pub fn is_empty(&self) -> bool {
        self.mmap.len() == 0
    }
}
/// Zero-copy image view over raw interleaved pixel data.
pub struct ImageView<'a> {
    data: &'a [u8],
    width: u32,
    height: u32,
    channels: u8,
}
impl<'a> ImageView<'a> {
    /// Create new image view from raw data.
    ///
    /// # Errors
    /// Returns `ScipixError::InvalidInput` when `data.len()` does not equal
    /// `width * height * channels`.
    pub fn new(data: &'a [u8], width: u32, height: u32, channels: u8) -> Result<Self> {
        // usize arithmetic: the former u32 product could overflow for large
        // images (e.g. 40k x 40k RGBA) and accept mismatched buffers.
        let expected_len = width as usize * height as usize * channels as usize;
        if data.len() != expected_len {
            return Err(ScipixError::InvalidInput(format!(
                "Invalid data length: expected {}, got {}",
                expected_len,
                data.len()
            )));
        }
        Ok(Self {
            data,
            width,
            height,
            channels,
        })
    }
    /// Get pixel at (x, y). Panics if the coordinates are out of bounds.
    pub fn pixel(&self, x: u32, y: u32) -> &[u8] {
        // usize arithmetic to avoid u32 offset overflow on large images.
        let offset = (y as usize * self.width as usize + x as usize) * self.channels as usize;
        &self.data[offset..offset + self.channels as usize]
    }
    /// Get raw data slice
    pub fn data(&self) -> &[u8] {
        self.data
    }
    /// Get dimensions as (width, height).
    pub fn dimensions(&self) -> (u32, u32) {
        (self.width, self.height)
    }
    /// Get number of channels
    pub fn channels(&self) -> u8 {
        self.channels
    }
    /// Create subview (region of interest).
    ///
    /// # Errors
    /// Returns `ScipixError::InvalidInput` when the region exceeds the view.
    pub fn subview(&self, x: u32, y: u32, width: u32, height: u32) -> Result<Self> {
        // checked_add: `x + width` could overflow u32 and wrap past the
        // bounds check for adversarial inputs.
        if x.checked_add(width).map_or(true, |xe| xe > self.width)
            || y.checked_add(height).map_or(true, |ye| ye > self.height)
        {
            return Err(ScipixError::InvalidInput(
                "Subview out of bounds".to_string(),
            ));
        }
        // Full-width rows are contiguous in the parent buffer, so this case
        // can borrow directly: true zero-copy, no allocation, no leak.
        if width == self.width {
            let row_bytes = self.width as usize * self.channels as usize;
            let start = y as usize * row_bytes;
            let end = start + height as usize * row_bytes;
            return Ok(Self {
                data: &self.data[start..end],
                width,
                height,
                channels: self.channels,
            });
        }
        // Narrow subviews need a copy because rows are strided.
        let mut subview_data = Vec::new();
        for row in y..y + height {
            let start = (row as usize * self.width as usize + x as usize) * self.channels as usize;
            let end = start + width as usize * self.channels as usize;
            subview_data.extend_from_slice(&self.data[start..end]);
        }
        // FIXME: this leaks the copied buffer to satisfy the 'a lifetime.
        // Proper zero-copy needs stride support (or an owning view type /
        // arena allocator) — flagged rather than silently kept.
        let leaked = Box::leak(subview_data.into_boxed_slice());
        Ok(Self {
            data: leaked,
            width,
            height,
            channels: self.channels,
        })
    }
}
/// Arena allocator for temporary allocations.
///
/// Hands out zero-initialized byte slices bump-allocated from one growable
/// buffer; `reset` recycles the whole region without releasing capacity.
pub struct Arena {
    buffer: Vec<u8>, // backing storage
    offset: usize,   // bump pointer: next free byte
}
impl Arena {
    /// Create new arena with the given initial capacity in bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            offset: 0,
        }
    }
    /// Allocate `size` bytes aligned to `align` (must be non-zero).
    ///
    /// The returned slice is zero-initialized. Alignment is relative to the
    /// start of the arena buffer, not an absolute-address guarantee.
    pub fn alloc(&mut self, size: usize, align: usize) -> &mut [u8] {
        // Round the bump pointer up to the requested alignment.
        let padding = (align - (self.offset % align)) % align;
        self.offset += padding;
        let start = self.offset;
        let end = start + size;
        // BUGFIX: grow with zero-fill instead of `reserve` + `set_len`.
        // The old code handed out a slice over uninitialized bytes, which
        // is undefined behavior; `resize` initializes the new region.
        if end > self.buffer.len() {
            self.buffer.resize(end, 0);
        }
        self.offset = end;
        &mut self.buffer[start..end]
    }
    /// Reset arena (keeps capacity for reuse).
    pub fn reset(&mut self) {
        self.offset = 0;
        self.buffer.clear();
    }
    /// Get current usage in bytes (includes alignment padding).
    pub fn usage(&self) -> usize {
        self.offset
    }
    /// Get total backing capacity in bytes.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }
}
/// Global buffer pools for common sizes
pub struct GlobalPools {
    small: BufferPool<Vec<u8>>,  // 1KB buffers
    medium: BufferPool<Vec<u8>>, // 64KB buffers
    large: BufferPool<Vec<u8>>,  // 1MB buffers
}
impl GlobalPools {
    /// Build the three size-tiered pools with fixed warm/max counts.
    fn new() -> Self {
        Self {
            small: BufferPool::new(|| Vec::with_capacity(1024), 10, 100),
            medium: BufferPool::new(|| Vec::with_capacity(64 * 1024), 5, 50),
            large: BufferPool::new(|| Vec::with_capacity(1024 * 1024), 2, 20),
        }
    }
    /// Get the process-wide pools instance (lazily initialized, thread-safe).
    pub fn get() -> &'static Self {
        static POOLS: std::sync::OnceLock<GlobalPools> = std::sync::OnceLock::new();
        POOLS.get_or_init(GlobalPools::new)
    }
    /// Acquire small buffer (1KB)
    pub fn acquire_small(&self) -> PooledBuffer<Vec<u8>> {
        self.small.acquire()
    }
    /// Acquire medium buffer (64KB)
    pub fn acquire_medium(&self) -> PooledBuffer<Vec<u8>> {
        self.medium.acquire()
    }
    /// Acquire large buffer (1MB)
    pub fn acquire_large(&self) -> PooledBuffer<Vec<u8>> {
        self.large.acquire()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_buffer_pool() {
        let pool = BufferPool::new(|| Vec::with_capacity(1024), 2, 10);
        assert_eq!(pool.size(), 2);
        let mut buf1 = pool.acquire();
        assert_eq!(buf1.capacity(), 1024);
        buf1.extend_from_slice(b"test");
        // Acquiring took one buffer out of the pool...
        assert_eq!(pool.size(), 1);
        drop(buf1);
        // ...and dropping returns it. BUGFIX: was `assert_eq!(pool.size(), 3)`,
        // which can never hold — acquire pops one and drop pushes one back,
        // so the net size is unchanged.
        assert_eq!(pool.size(), 2);
    }

    #[test]
    fn test_mmap_model() {
        let mut temp = NamedTempFile::new().unwrap();
        temp.write_all(b"test model data").unwrap();
        temp.flush().unwrap();
        let mmap = MmapModel::from_file(temp.path()).unwrap();
        assert_eq!(mmap.as_slice(), b"test model data");
        assert_eq!(mmap.len(), 15);
    }

    #[test]
    fn test_image_view() {
        // 2x2 RGBA image.
        let data = vec![
            255, 0, 0, 255, // Red pixel
            0, 255, 0, 255, // Green pixel
            0, 0, 255, 255, // Blue pixel
            255, 255, 255, 255, // White pixel
        ];
        let view = ImageView::new(&data, 2, 2, 4).unwrap();
        assert_eq!(view.dimensions(), (2, 2));
        assert_eq!(view.pixel(0, 0), &[255, 0, 0, 255]);
        assert_eq!(view.pixel(1, 1), &[255, 255, 255, 255]);
    }

    #[test]
    fn test_arena() {
        let mut arena = Arena::with_capacity(1024);
        let slice1 = arena.alloc(100, 8);
        assert_eq!(slice1.len(), 100);
        let slice2 = arena.alloc(200, 8);
        assert_eq!(slice2.len(), 200);
        assert!(arena.usage() >= 300);
        arena.reset();
        assert_eq!(arena.usage(), 0);
    }

    #[test]
    fn test_global_pools() {
        let pools = GlobalPools::get();
        let small = pools.acquire_small();
        assert!(small.capacity() >= 1024);
        let medium = pools.acquire_medium();
        assert!(medium.capacity() >= 64 * 1024);
        let large = pools.acquire_large();
        assert!(large.capacity() >= 1024 * 1024);
    }
}

View File

@@ -0,0 +1,169 @@
//! Performance optimization utilities for scipix OCR
//!
//! This module provides runtime feature detection and optimized code paths
//! for different CPU architectures and capabilities.
pub mod batch;
pub mod memory;
pub mod parallel;
pub mod quantize;
pub mod simd;
use std::sync::OnceLock;
/// CPU features detected at runtime
#[derive(Debug, Clone, Copy)]
pub struct CpuFeatures {
    pub avx2: bool,    // x86_64 only
    pub avx512f: bool, // x86_64 only
    pub neon: bool,    // aarch64 only
    pub sse4_2: bool,  // x86_64 only
}
/// Cached detection result: the probe runs once per process.
static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();
/// Detect CPU features at runtime.
///
/// Detection runs once and is cached in `CPU_FEATURES`; later calls return
/// the cached copy. Architectures other than x86_64/aarch64 report all
/// features as false.
pub fn detect_features() -> CpuFeatures {
    *CPU_FEATURES.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        {
            CpuFeatures {
                avx2: is_x86_feature_detected!("avx2"),
                avx512f: is_x86_feature_detected!("avx512f"),
                neon: false,
                sse4_2: is_x86_feature_detected!("sse4.2"),
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            CpuFeatures {
                avx2: false,
                avx512f: false,
                neon: std::arch::is_aarch64_feature_detected!("neon"),
                sse4_2: false,
            }
        }
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            CpuFeatures {
                avx2: false,
                avx512f: false,
                neon: false,
                sse4_2: false,
            }
        }
    })
}
/// Get the detected CPU features (thin alias for `detect_features`).
pub fn get_features() -> CpuFeatures {
    detect_features()
}
/// Runtime dispatch to optimized implementation
pub trait OptimizedOp<T> {
    /// Execute the operation with the best available implementation
    fn execute(&self, input: T) -> T;
    /// Execute with SIMD if available, fallback to scalar.
    /// Default implementation dispatches on the detected CPU features.
    fn execute_auto(&self, input: T) -> T {
        let features = get_features();
        if features.avx2 || features.avx512f || features.neon {
            self.execute_simd(input)
        } else {
            self.execute_scalar(input)
        }
    }
    /// SIMD implementation
    fn execute_simd(&self, input: T) -> T;
    /// Scalar fallback implementation
    fn execute_scalar(&self, input: T) -> T;
}
/// Optimization level configuration
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OptLevel {
    /// No optimizations, scalar code only
    None,
    /// Use SIMD when available
    Simd,
    /// Use SIMD + parallel processing
    Parallel,
    /// All optimizations including memory optimizations
    Full,
}
impl Default for OptLevel {
    // Everything on by default.
    fn default() -> Self {
        OptLevel::Full
    }
}
/// Global optimization configuration (first write/read wins, then frozen).
static OPT_LEVEL: OnceLock<OptLevel> = OnceLock::new();
/// Set the optimization level.
/// NOTE(review): silently ignored once the level has been read or set —
/// callers receive no signal that the value did not take effect.
pub fn set_opt_level(level: OptLevel) {
    OPT_LEVEL.set(level).ok();
}
/// Get the current optimization level (initializes to `Full` on first read).
pub fn get_opt_level() -> OptLevel {
    *OPT_LEVEL.get_or_init(OptLevel::default)
}
/// Check if SIMD optimizations are enabled (any level above `None`).
pub fn simd_enabled() -> bool {
    matches!(
        get_opt_level(),
        OptLevel::Simd | OptLevel::Parallel | OptLevel::Full
    )
}
/// Check if parallel optimizations are enabled (`Parallel` or `Full`).
pub fn parallel_enabled() -> bool {
    matches!(get_opt_level(), OptLevel::Parallel | OptLevel::Full)
}
/// Check if memory optimizations are enabled (`Full` only).
pub fn memory_opt_enabled() -> bool {
    matches!(get_opt_level(), OptLevel::Full)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feature_detection() {
        let features = detect_features();
        println!("Detected features: {:?}", features);
        // NOTE(review): this assertion is a tautology (always true); it only
        // verifies that detection does not panic.
        assert!(
            features.avx2
                || features.avx512f
                || features.neon
                || features.sse4_2
                || (!features.avx2 && !features.avx512f && !features.neon && !features.sse4_2)
        );
    }

    #[test]
    fn test_opt_level() {
        // First read initializes the OnceLock to Full...
        assert_eq!(get_opt_level(), OptLevel::Full);
        set_opt_level(OptLevel::Simd);
        // Can't change after first init, should still be Full.
        // NOTE(review): tests share the process-wide OPT_LEVEL; if this set
        // ever ran before any get in any test thread, sibling tests would
        // observe Simd instead of Full.
        assert_eq!(get_opt_level(), OptLevel::Full);
    }

    #[test]
    fn test_optimization_checks() {
        // All three predicates are true at the default Full level.
        assert!(simd_enabled());
        assert!(parallel_enabled());
        assert!(memory_opt_enabled());
    }
}

View File

@@ -0,0 +1,335 @@
//! Parallel processing utilities for OCR pipeline
//!
//! Provides parallel image preprocessing, batch OCR, and pipelined execution.
use image::DynamicImage;
use rayon::prelude::*;
use std::sync::Arc;
use tokio::sync::Semaphore;
use super::parallel_enabled;
/// Parallel preprocessing of multiple images.
///
/// Falls back to a sequential pass when parallelism is disabled; output
/// order matches input order in both paths.
pub fn parallel_preprocess<F>(images: Vec<DynamicImage>, preprocess_fn: F) -> Vec<DynamicImage>
where
    F: Fn(DynamicImage) -> DynamicImage + Sync + Send,
{
    if !parallel_enabled() {
        return images.into_iter().map(preprocess_fn).collect();
    }
    images.into_par_iter().map(preprocess_fn).collect()
}
/// Parallel preprocessing with per-image error handling.
/// Each image yields its own Result; one failure does not abort the rest.
pub fn parallel_preprocess_result<F, E>(
    images: Vec<DynamicImage>,
    preprocess_fn: F,
) -> Vec<std::result::Result<DynamicImage, E>>
where
    F: Fn(DynamicImage) -> std::result::Result<DynamicImage, E> + Sync + Send,
    E: Send,
{
    if !parallel_enabled() {
        return images.into_iter().map(preprocess_fn).collect();
    }
    images.into_par_iter().map(preprocess_fn).collect()
}
/// Two-stage executor for OCR workflows (e.g. preprocess -> detect).
///
/// Composes the two stage functions and maps the composition over a batch,
/// running items in parallel when parallel optimizations are enabled.
pub struct PipelineExecutor<T, U, V> {
    stage1: Arc<dyn Fn(T) -> U + Send + Sync>,
    stage2: Arc<dyn Fn(U) -> V + Send + Sync>,
}

impl<T, U, V> PipelineExecutor<T, U, V>
where
    T: Send,
    U: Send,
    V: Send,
{
    /// Build a two-stage pipeline from the given stage functions.
    pub fn new<F1, F2>(stage1: F1, stage2: F2) -> Self
    where
        F1: Fn(T) -> U + Send + Sync + 'static,
        F2: Fn(U) -> V + Send + Sync + 'static,
    {
        Self {
            stage1: Arc::new(stage1),
            stage2: Arc::new(stage2),
        }
    }

    /// Execute both stages over every input, preserving input order.
    pub fn execute_batch(&self, inputs: Vec<T>) -> Vec<V> {
        let run = |item: T| (self.stage2)((self.stage1)(item));
        if parallel_enabled() {
            inputs.into_par_iter().map(run).collect()
        } else {
            inputs.into_iter().map(run).collect()
        }
    }
}
/// Three-stage pipeline executor.
///
/// Composes three stage functions and maps the composition over a batch,
/// in parallel when enabled.
pub struct Pipeline3<T, U, V, W> {
    stage1: Arc<dyn Fn(T) -> U + Send + Sync>,
    stage2: Arc<dyn Fn(U) -> V + Send + Sync>,
    stage3: Arc<dyn Fn(V) -> W + Send + Sync>,
}

impl<T, U, V, W> Pipeline3<T, U, V, W>
where
    T: Send,
    U: Send,
    V: Send,
    W: Send,
{
    /// Build a three-stage pipeline from the given stage functions.
    pub fn new<F1, F2, F3>(stage1: F1, stage2: F2, stage3: F3) -> Self
    where
        F1: Fn(T) -> U + Send + Sync + 'static,
        F2: Fn(U) -> V + Send + Sync + 'static,
        F3: Fn(V) -> W + Send + Sync + 'static,
    {
        Self {
            stage1: Arc::new(stage1),
            stage2: Arc::new(stage2),
            stage3: Arc::new(stage3),
        }
    }

    /// Run all three stages over each input, preserving input order.
    pub fn execute_batch(&self, inputs: Vec<T>) -> Vec<W> {
        let run = |item: T| (self.stage3)((self.stage2)((self.stage1)(item)));
        if parallel_enabled() {
            inputs.into_par_iter().map(run).collect()
        } else {
            inputs.into_iter().map(run).collect()
        }
    }
}
/// Parallel map with configurable chunk size.
///
/// `chunk_size` is a lower bound on items handled per rayon task (via
/// `with_min_len`), limiting scheduling overhead for cheap per-item work.
pub fn parallel_map_chunked<T, U, F>(items: Vec<T>, chunk_size: usize, map_fn: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Sync + Send,
{
    if !parallel_enabled() {
        return items.into_iter().map(map_fn).collect();
    }
    items
        .into_par_iter()
        .with_min_len(chunk_size)
        .map(map_fn)
        .collect()
}
/// Async parallel executor with concurrency limit
pub struct AsyncParallelExecutor {
    /// Permits cap the number of in-flight spawned tasks.
    semaphore: Arc<Semaphore>,
}
impl AsyncParallelExecutor {
    /// Create executor with maximum concurrency limit
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }
    /// Execute async tasks with concurrency limit.
    ///
    /// A permit is acquired BEFORE each spawn, so the submission loop itself
    /// throttles. Results are collected in submission order; tasks that
    /// panic are silently skipped, so the output may be shorter than input.
    pub async fn execute<T, F, Fut>(&self, tasks: Vec<T>, executor: F) -> Vec<Fut::Output>
    where
        T: Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let mut handles = Vec::new();
        for task in tasks {
            // unwrap is safe here: this semaphore is never closed.
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();
            let executor = executor.clone();
            let handle = tokio::spawn(async move {
                let result = executor(task).await;
                drop(permit); // Release semaphore
                result
            });
            handles.push(handle);
        }
        // Wait for all tasks to complete
        let mut results = Vec::new();
        for handle in handles {
            if let Ok(result) = handle.await {
                results.push(result);
            }
        }
        results
    }
    /// Execute fallible async tasks with the same throttling semantics as
    /// `execute`; panicked tasks are dropped from the output.
    pub async fn execute_result<T, F, Fut, R, E>(
        &self,
        tasks: Vec<T>,
        executor: F,
    ) -> Vec<std::result::Result<R, E>>
    where
        T: Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = std::result::Result<R, E>> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let mut handles = Vec::new();
        for task in tasks {
            // unwrap is safe here: this semaphore is never closed.
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();
            let executor = executor.clone();
            let handle = tokio::spawn(async move {
                let result = executor(task).await;
                drop(permit);
                result
            });
            handles.push(handle);
        }
        let mut results = Vec::new();
        for handle in handles {
            match handle.await {
                Ok(result) => results.push(result),
                Err(_) => continue, // Task panicked
            }
        }
        results
    }
}
/// Work-stealing parallel map for unbalanced workloads.
/// `with_min_len(1)` lets rayon steal at single-item granularity.
pub fn parallel_unbalanced<T, U, F>(items: Vec<T>, map_fn: F) -> Vec<U>
where
    T: Send,
    U: Send,
    F: Fn(T) -> U + Sync + Send,
{
    if !parallel_enabled() {
        return items.into_iter().map(map_fn).collect();
    }
    // Use adaptive strategy for unbalanced work
    items
        .into_par_iter()
        .with_min_len(1) // Allow fine-grained work stealing
        .map(map_fn)
        .collect()
}
/// Get optimal thread count for the current system.
/// Reports the size of rayon's current (possibly custom) pool.
pub fn optimal_thread_count() -> usize {
    rayon::current_num_threads()
}
/// Set global thread pool size.
/// NOTE(review): fails silently (`.ok()`) if the global rayon pool was
/// already initialized — the requested size may not take effect.
pub fn set_thread_count(threads: usize) {
    rayon::ThreadPoolBuilder::new()
        .num_threads(threads)
        .build_global()
        .ok();
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parallel_map() {
        let data: Vec<i32> = (0..100).collect();
        let result = parallel_map_chunked(data, 10, |x| x * 2);
        // Order must be preserved regardless of the parallel path taken.
        assert_eq!(result.len(), 100);
        assert_eq!(result[0], 0);
        assert_eq!(result[50], 100);
        assert_eq!(result[99], 198);
    }

    #[test]
    fn test_pipeline_executor() {
        let pipeline = PipelineExecutor::new(|x: i32| x + 1, |x: i32| x * 2);
        let inputs = vec![1, 2, 3, 4, 5];
        let results = pipeline.execute_batch(inputs);
        // (x + 1) * 2 for each input.
        assert_eq!(results, vec![4, 6, 8, 10, 12]);
    }

    #[test]
    fn test_pipeline3() {
        let pipeline = Pipeline3::new(|x: i32| x + 1, |x: i32| x * 2, |x: i32| x - 1);
        let inputs = vec![1, 2, 3];
        let results = pipeline.execute_batch(inputs);
        // (1+1)*2-1 = 3, (2+1)*2-1 = 5, (3+1)*2-1 = 7
        assert_eq!(results, vec![3, 5, 7]);
    }

    #[tokio::test]
    async fn test_async_executor() {
        let executor = AsyncParallelExecutor::new(2);
        let tasks = vec![1, 2, 3, 4, 5];
        let results = executor
            .execute(tasks, |x| async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                x * 2
            })
            .await;
        // All five tasks complete despite the concurrency cap of 2.
        assert_eq!(results.len(), 5);
        assert!(results.contains(&2));
        assert!(results.contains(&10));
    }

    #[test]
    fn test_optimal_threads() {
        let threads = optimal_thread_count();
        assert!(threads > 0);
        assert!(threads <= num_cpus::get());
    }
}

View File

@@ -0,0 +1,339 @@
//! Model quantization utilities
//!
//! Provides INT8 quantization for model weights and activations to reduce
//! memory usage and improve inference speed.
use std::f32;
/// Quantization parameters for the affine i8 mapping:
/// `q = round(x / scale + zero_point)`, `x ≈ (q - zero_point) * scale`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    pub scale: f32,
    pub zero_point: i8,
}
impl QuantParams {
    /// Calculate asymmetric quantization parameters from min/max values.
    ///
    /// BUGFIX: a degenerate range (min == max, or non-finite bounds) used to
    /// produce scale 0, making every later division by `scale` NaN/inf; it
    /// now falls back to scale 1.0 so downstream math stays well-defined.
    pub fn from_range(min: f32, max: f32) -> Self {
        let qmin = i8::MIN as f32;
        let qmax = i8::MAX as f32;
        let range = max - min;
        let scale = if range > 0.0 && range.is_finite() {
            range / (qmax - qmin)
        } else {
            1.0
        };
        // `as i8` saturates on overflow/NaN, so the cast itself is safe.
        let zero_point = (qmin - min / scale).round() as i8;
        Self { scale, zero_point }
    }
    /// Calculate parameters from the observed min/max of `data`.
    /// An empty slice degenerates to the scale-1.0 fallback of `from_range`.
    pub fn from_data(data: &[f32]) -> Self {
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        Self::from_range(min, max)
    }
    /// Symmetric quantization (zero_point = 0) over the range [-abs_max, abs_max].
    ///
    /// BUGFIX: abs_max of 0 (all-zero data) used to yield scale 0; it now
    /// falls back to scale 1.0.
    pub fn symmetric(abs_max: f32) -> Self {
        let scale = if abs_max > 0.0 && abs_max.is_finite() {
            abs_max / 127.0
        } else {
            1.0
        };
        Self {
            scale,
            zero_point: 0,
        }
    }
}
/// Quantize f32 weights to i8 using asymmetric parameters derived from the
/// data's own min/max range. Returns the quantized codes together with the
/// parameters required to dequantize them later.
pub fn quantize_weights(weights: &[f32]) -> (Vec<i8>, QuantParams) {
    let params = QuantParams::from_data(weights);
    let quantized = quantize_with_params(weights, params);
    (quantized, params)
}
/// Quantize a slice of f32 values with the given, pre-computed parameters.
pub fn quantize_with_params(weights: &[f32], params: QuantParams) -> Vec<i8> {
    weights.iter().map(|&w| quantize_value(w, params)).collect()
}
/// Quantize a single value: `round(v / scale + zero_point)`, explicitly
/// clamped (saturated) to the i8 range.
#[inline]
pub fn quantize_value(value: f32, params: QuantParams) -> i8 {
    let scaled = value / params.scale + params.zero_point as f32;
    scaled.round().clamp(i8::MIN as f32, i8::MAX as f32) as i8
}
/// Dequantize a slice of i8 codes back to approximate f32 values.
pub fn dequantize(quantized: &[i8], params: QuantParams) -> Vec<f32> {
    quantized
        .iter()
        .map(|&q| dequantize_value(q, params))
        .collect()
}
/// Dequantize a single code: `(q - zero_point) * scale`.
#[inline]
pub fn dequantize_value(quantized: i8, params: QuantParams) -> f32 {
    (quantized as f32 - params.zero_point as f32) * params.scale
}
/// Quantized tensor representation: i8 payload plus one set of quantization
/// parameters shared by the whole tensor (per-tensor quantization).
pub struct QuantizedTensor {
    /// Quantized values, flattened in the same order as the source data.
    pub data: Vec<i8>,
    /// Parameters used to quantize (and needed to dequantize) `data`.
    pub params: QuantParams,
    /// Logical tensor shape; not validated against `data.len()` here.
    pub shape: Vec<usize>,
}
impl QuantizedTensor {
    /// Create from f32 tensor using asymmetric min/max quantization.
    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
        let (quantized, params) = quantize_weights(data);
        Self {
            data: quantized,
            params,
            shape,
        }
    }
    /// Create with symmetric quantization (zero_point = 0), scaled by the
    /// largest absolute value in the data.
    pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Self {
        let abs_max = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
        let params = QuantParams::symmetric(abs_max);
        let quantized = quantize_with_params(data, params);
        Self {
            data: quantized,
            params,
            shape,
        }
    }
    /// Dequantize to f32 (lossy round trip).
    pub fn to_f32(&self) -> Vec<f32> {
        dequantize(&self.data, self.params)
    }
    /// Get size in bytes, including the params and shape bookkeeping.
    pub fn size_bytes(&self) -> usize {
        self.data.len()
            + std::mem::size_of::<QuantParams>()
            + self.shape.len() * std::mem::size_of::<usize>()
    }
    /// Calculate memory savings vs f32 storage of the same element count.
    /// Slightly below 4.0 because `size_bytes` counts the metadata overhead.
    pub fn compression_ratio(&self) -> f32 {
        let f32_size = self.data.len() * std::mem::size_of::<f32>();
        let quantized_size = self.size_bytes();
        f32_size as f32 / quantized_size as f32
    }
}
/// Per-channel quantization for conv/linear layers: each output channel gets
/// its own scale/zero-point, which preserves accuracy when channel ranges
/// differ widely.
pub struct PerChannelQuant {
    /// Quantized values, channel-major (same layout as the source data).
    pub data: Vec<i8>,
    /// One parameter set per output channel; `params[ch]` covers channel `ch`.
    pub params: Vec<QuantParams>,
    /// Original tensor shape; `shape[0]` is the channel count.
    pub shape: Vec<usize>,
}
impl PerChannelQuant {
    /// Quantize with per-channel parameters
    /// For a weight tensor of shape [out_channels, in_channels, ...],
    /// use separate params for each output channel
    ///
    /// # Panics
    /// Panics if `shape` is empty.
    ///
    /// NOTE(review): `channel_size` is integer division — assumes
    /// `data.len()` is a multiple of `shape[0]`; trailing elements would be
    /// silently dropped otherwise. TODO confirm callers guarantee this.
    pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
        if shape.is_empty() {
            panic!("Shape cannot be empty");
        }
        let out_channels = shape[0];
        let channel_size = data.len() / out_channels;
        let mut all_quantized = Vec::with_capacity(data.len());
        let mut params = Vec::with_capacity(out_channels);
        for ch in 0..out_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &data[start..end];
            // Independent min/max calibration per channel.
            let ch_params = QuantParams::from_data(channel_data);
            let ch_quantized = quantize_with_params(channel_data, ch_params);
            all_quantized.extend(ch_quantized);
            params.push(ch_params);
        }
        Self {
            data: all_quantized,
            params,
            shape,
        }
    }
    /// Dequantize to f32, applying each channel's own parameters.
    pub fn to_f32(&self) -> Vec<f32> {
        let out_channels = self.shape[0];
        let channel_size = self.data.len() / out_channels;
        let mut result = Vec::with_capacity(self.data.len());
        for ch in 0..out_channels {
            let start = ch * channel_size;
            let end = start + channel_size;
            let channel_data = &self.data[start..end];
            let ch_params = self.params[ch];
            result.extend(dequantize(channel_data, ch_params));
        }
        result
    }
}
/// Dynamic quantization - derives parameters from the data at runtime,
/// clipping tail outliers via a calibration percentile.
pub struct DynamicQuantizer {
    /// Values beyond this percentile (e.g. 99.9) are clipped on both tails.
    percentile: f32,
}
impl DynamicQuantizer {
    /// Create quantizer with calibration percentile
    /// percentile: clip values beyond this percentile (e.g., 99.9)
    pub fn new(percentile: f32) -> Self {
        Self { percentile }
    }
    /// Quantize with calibration.
    ///
    /// The range is taken between the lower and upper `percentile` samples of
    /// the sorted data, so outliers on either tail are clipped instead of
    /// stretching the scale.
    ///
    /// # Panics
    /// Panics if `data` contains NaN (NaN is not totally ordered).
    pub fn quantize(&self, data: &[f32]) -> (Vec<i8>, QuantParams) {
        if data.is_empty() {
            // Nothing to calibrate on; the old code panicked on `len() - 1`.
            return (Vec::new(), QuantParams::symmetric(1.0));
        }
        let mut sorted = data.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
        // Upper-percentile index; the lower bound mirrors it from the other
        // end of the sorted data.
        let hi = ((sorted.len() as f32 * self.percentile / 100.0) as usize).min(sorted.len() - 1);
        let lo = sorted.len() - 1 - hi;
        // Fix: the previous code used `min = -sorted[sorted.len() - idx]`,
        // which (a) negated an upper-tail sample instead of reading the
        // lower-tail one, fabricating a bogus range for all-positive or
        // all-negative data, and (b) indexed out of bounds (sorted[len])
        // whenever idx was 0. Read the actual order statistics instead; the
        // min/max below also tolerates percentiles < 50.
        let params = QuantParams::from_range(sorted[lo.min(hi)], sorted[lo.max(hi)]);
        let quantized = quantize_with_params(data, params);
        (quantized, params)
    }
}
/// Calculate quantization error (MSE)
///
/// Mean squared difference between the original values and the quantized
/// codes after a dequantize round trip.
pub fn quantization_error(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
    let restored = dequantize(quantized, params);
    let total: f32 = original
        .iter()
        .zip(&restored)
        .map(|(o, d)| {
            let diff = o - d;
            diff * diff
        })
        .sum();
    total / original.len() as f32
}
/// Calculate signal-to-quantization-noise ratio (SQNR) in dB
///
/// Higher is better; returns `inf` for a lossless round trip (zero noise).
pub fn sqnr(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
    let restored = dequantize(quantized, params);
    let n = original.len() as f32;
    let signal_power: f32 = original.iter().map(|x| x * x).sum::<f32>() / n;
    let noise_power: f32 = original
        .iter()
        .zip(&restored)
        .map(|(o, d)| (o - d) * (o - d))
        .sum::<f32>()
        / n;
    // Power ratio expressed in decibels.
    10.0 * (signal_power / noise_power).log10()
}
#[cfg(test)]
mod tests {
    use super::*;
    // Round-tripping through i8 should reproduce values within one
    // quantization step (scale is 2/255 here, so well under 0.01).
    #[test]
    fn test_quantize_dequantize() {
        let weights = vec![0.0, 0.5, 1.0, -0.5, -1.0];
        let (quantized, params) = quantize_weights(&weights);
        let dequantized = dequantize(&quantized, params);
        // Check approximate equality
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.01, "orig: {}, deq: {}", orig, deq);
        }
    }
    // Symmetric params pin zero_point at 0 so real 0.0 maps to code 0.
    #[test]
    fn test_symmetric_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::symmetric(1.0);
        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 1.0 / 127.0).abs() < 1e-6);
        let quantized = quantize_with_params(&data, params);
        assert_eq!(quantized[2], 0); // 0.0 should map to 0
    }
    #[test]
    fn test_quantized_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = QuantizedTensor::from_f32(&data, vec![2, 2]);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data.len(), 4);
        let dequantized = tensor.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }
    // Channels with very different ranges keep independent scales, so both
    // round-trip within a coarse tolerance.
    #[test]
    fn test_per_channel_quant() {
        // 2 channels, 3 values each
        let data = vec![
            1.0, 2.0, 3.0, // Channel 0
            10.0, 20.0, 30.0, // Channel 1
        ];
        let quant = PerChannelQuant::from_f32(&data, vec![2, 3]);
        assert_eq!(quant.params.len(), 2);
        let dequantized = quant.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 1.0);
        }
    }
    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (quantized, params) = quantize_weights(&original);
        let error = quantization_error(&original, &quantized, params);
        assert!(error < 0.1); // Should be small for simple data
        let snr = sqnr(&original, &quantized, params);
        assert!(snr > 30.0); // Should have good SNR
    }
    // i8 payload is 1/4 the size of f32; the ratio lands slightly below 4x
    // because size_bytes() also counts the params/shape bookkeeping.
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();
        let tensor = QuantizedTensor::from_f32(&data, vec![1000]);
        let ratio = tensor.compression_ratio();
        assert!(ratio > 3.5); // Should be ~4x compression
    }
    #[test]
    fn test_dynamic_quantizer() {
        let mut data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        data.push(1000.0); // Outlier
        let quantizer = DynamicQuantizer::new(99.0);
        let (quantized, params) = quantizer.quantize(&data);
        assert_eq!(quantized.len(), 101);
        // The outlier should be clipped
        assert!(params.scale > 0.0);
    }
}

View File

@@ -0,0 +1,597 @@
//! SIMD-accelerated image processing operations
//!
//! Provides optimized implementations for common image operations using
//! AVX2, AVX-512, and ARM NEON intrinsics.
use super::{get_features, simd_enabled};
/// Convert RGBA image to grayscale using optimized SIMD operations
///
/// Dispatches at runtime to the best available backend (AVX2, then SSE4.2 on
/// x86_64; NEON on aarch64) and falls back to the scalar path when SIMD is
/// globally disabled or unsupported. `gray` must hold `rgba.len() / 4` pixels.
pub fn simd_grayscale(rgba: &[u8], gray: &mut [u8]) {
    if !simd_enabled() {
        return scalar_grayscale(rgba, gray);
    }
    let features = get_features();
    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // SAFETY: guarded by the runtime AVX2 feature check above.
            unsafe { avx2_grayscale(rgba, gray) }
        } else if features.sse4_2 {
            // SAFETY: guarded by the runtime SSE4.2 feature check above.
            unsafe { sse_grayscale(rgba, gray) }
        } else {
            scalar_grayscale(rgba, gray)
        }
    }
    #[cfg(target_arch = "aarch64")]
    {
        if features.neon {
            // SAFETY: guarded by the runtime NEON feature check above.
            unsafe { neon_grayscale(rgba, gray) }
        } else {
            scalar_grayscale(rgba, gray)
        }
    }
    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        scalar_grayscale(rgba, gray)
    }
}
/// Scalar fallback for grayscale conversion.
///
/// Computes 8-bit fixed-point BT.601 luma per pixel; the alpha channel is
/// ignored. Panics unless `rgba.len() / 4 == gray.len()`.
fn scalar_grayscale(rgba: &[u8], gray: &mut [u8]) {
    assert_eq!(
        rgba.len() / 4,
        gray.len(),
        "RGBA length must be 4x grayscale length"
    );
    for (dst, px) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let (r, g, b) = (px[0] as u32, px[1] as u32, px[2] as u32);
        // ITU-R BT.601 luma coefficients: 0.299 R + 0.587 G + 0.114 B,
        // scaled by 256 (77/150/29) and shifted back down.
        *dst = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// AVX2-gated grayscale kernel.
//
// # Safety
// Caller must ensure AVX2 is available at runtime, and that
// `rgba.len() >= 4 * gray.len()` so the unchecked accesses stay in bounds.
unsafe fn avx2_grayscale(rgba: &[u8], gray: &mut [u8]) {
    use std::arch::x86_64::*;
    let len = gray.len();
    let mut i = 0;
    // Process 8 pixels at a time (32 RGBA bytes)
    while i + 8 <= len {
        // Load 32 bytes (8 RGBA pixels)
        let rgba_ptr = rgba.as_ptr().add(i * 4);
        let _pixels = _mm256_loadu_si256(rgba_ptr as *const __m256i);
        // NOTE(review): `_pixels` is never used — the loop below computes the
        // luma per pixel in scalar code, so this is effectively an unrolled
        // scalar kernel with a placeholder SIMD load.
        // Separate RGBA channels (simplified - actual implementation would use shuffles)
        // For production, use proper channel extraction
        // Store grayscale result
        for j in 0..8 {
            let pixel_idx = (i + j) * 4;
            let r = *rgba.get_unchecked(pixel_idx) as u32;
            let g = *rgba.get_unchecked(pixel_idx + 1) as u32;
            let b = *rgba.get_unchecked(pixel_idx + 2) as u32;
            // Same BT.601 fixed-point weights as the scalar path.
            *gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
        }
        i += 8;
    }
    // Handle remaining pixels
    scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
// SSE4.2-gated grayscale kernel.
//
// NOTE(review): despite the feature gate this body uses no intrinsics — it is
// a 4-wide unrolled scalar loop, kept as a dispatch target for parity.
//
// # Safety
// Caller must ensure SSE4.2 availability and `rgba.len() >= 4 * gray.len()`.
unsafe fn sse_grayscale(rgba: &[u8], gray: &mut [u8]) {
    #[allow(unused_imports)]
    use std::arch::x86_64::*;
    let len = gray.len();
    let mut i = 0;
    // Process 4 pixels at a time (16 RGBA bytes)
    while i + 4 <= len {
        for j in 0..4 {
            let pixel_idx = (i + j) * 4;
            let r = *rgba.get_unchecked(pixel_idx) as u32;
            let g = *rgba.get_unchecked(pixel_idx + 1) as u32;
            let b = *rgba.get_unchecked(pixel_idx + 2) as u32;
            *gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
        }
        i += 4;
    }
    scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
}
#[cfg(target_arch = "aarch64")]
// NEON dispatch target for aarch64.
//
// NOTE(review): the `use std::arch::aarch64::*` import is present but no NEON
// intrinsics are used yet — this is an 8-wide unrolled scalar loop.
//
// # Safety
// Caller must ensure `rgba.len() >= 4 * gray.len()`.
unsafe fn neon_grayscale(rgba: &[u8], gray: &mut [u8]) {
    use std::arch::aarch64::*;
    let len = gray.len();
    let mut i = 0;
    // Process 8 pixels at a time
    while i + 8 <= len {
        for j in 0..8 {
            let idx = (i + j) * 4;
            let r = *rgba.get_unchecked(idx) as u32;
            let g = *rgba.get_unchecked(idx + 1) as u32;
            let b = *rgba.get_unchecked(idx + 2) as u32;
            *gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
        }
        i += 8;
    }
    scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
}
/// Apply threshold to grayscale image using SIMD
///
/// Binarizes `gray` into `out` (255 for bright pixels, 0 for dark ones),
/// choosing the AVX2 kernel when available and the scalar path otherwise.
/// `out` must be at least as long as `gray`.
pub fn simd_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    if !simd_enabled() {
        return scalar_threshold(gray, thresh, out);
    }
    let features = get_features();
    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // SAFETY: guarded by the runtime AVX2 feature check above.
            unsafe { avx2_threshold(gray, thresh, out) }
        } else {
            scalar_threshold(gray, thresh, out)
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_threshold(gray, thresh, out)
    }
}
/// Scalar binarization: 255 where the pixel is strictly above `thresh`,
/// 0 otherwise.
///
/// Fix: the comparison was `>=`, but this module's own `test_threshold`
/// asserts that a pixel exactly equal to the threshold maps to 0, and the
/// AVX2 kernel uses a strict compare (`_mm256_cmpgt_epi8`) — so the scalar
/// path disagreed with both the test and the SIMD path at the boundary
/// value. Strict `>` also matches the common THRESH_BINARY convention.
fn scalar_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    for (g, o) in gray.iter().zip(out.iter_mut()) {
        *o = if *g > thresh { 255 } else { 0 };
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// AVX2 binarization kernel: 255 where gray > thresh (unsigned), else 0.
//
// Fix: the previous code used `_mm256_cmpgt_epi8`, which is a SIGNED byte
// compare — any pixel >= 128 was treated as negative and misclassified
// (e.g. gray 200 vs thresh 100 produced 0 instead of 255). AVX2 has no
// unsigned byte compare, so "g > t" is computed via saturating subtraction:
// `subs_epu8(g, t)` is non-zero exactly when g > t in unsigned arithmetic.
//
// # Safety
// Caller must ensure AVX2 is available at runtime and that
// `out.len() >= gray.len()`.
unsafe fn avx2_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    use std::arch::x86_64::*;
    let len = gray.len();
    let mut i = 0;
    let thresh_vec = _mm256_set1_epi8(thresh as i8);
    let zero = _mm256_setzero_si256();
    let ones = _mm256_set1_epi8(-1); // 0xFF
    // Process 32 bytes at a time
    while i + 32 <= len {
        let gray_vec = _mm256_loadu_si256(gray.as_ptr().add(i) as *const __m256i);
        // diff == 0  <=>  gray <= thresh (unsigned, saturating subtract).
        let diff = _mm256_subs_epu8(gray_vec, thresh_vec);
        let le_mask = _mm256_cmpeq_epi8(diff, zero);
        // Invert: 0xFF where gray > thresh, 0x00 elsewhere.
        let result = _mm256_andnot_si256(le_mask, ones);
        _mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut __m256i, result);
        i += 32;
    }
    // Handle remaining bytes
    scalar_threshold(&gray[i..], thresh, &mut out[i..]);
}
/// Normalize f32 tensor data using SIMD
///
/// Standardizes `data` in place to zero mean and (approximately) unit
/// variance, dispatching to the AVX2 kernel when available.
pub fn simd_normalize(data: &mut [f32]) {
    if !simd_enabled() {
        return scalar_normalize(data);
    }
    let features = get_features();
    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // SAFETY: guarded by the runtime AVX2 feature check above.
            unsafe { avx2_normalize(data) }
        } else {
            scalar_normalize(data)
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_normalize(data)
    }
}
/// Scalar standardization: shifts to zero mean and divides by the population
/// standard deviation (plus a small epsilon so constant input never divides
/// by zero).
fn scalar_normalize(data: &mut [f32]) {
    let n = data.len() as f32;
    let mean = data.iter().sum::<f32>() / n;
    let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    // Epsilon keeps the divisor strictly positive for constant inputs.
    let std_dev = variance.sqrt() + 1e-8;
    for value in data.iter_mut() {
        *value = (*value - mean) / std_dev;
    }
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// AVX2 standardization: three passes (mean, variance, normalize), each
// processing 8 lanes at a time with a scalar tail.
//
// # Safety
// Caller must ensure AVX2 is available at runtime.
unsafe fn avx2_normalize(data: &mut [f32]) {
    use std::arch::x86_64::*;
    // Calculate mean using SIMD
    let len = data.len();
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;
    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        sum = _mm256_add_ps(sum, vals);
        i += 8;
    }
    // Horizontal sum: reinterpret the vector as 8 f32 lanes and fold them,
    // then add the scalar tail (elements past the last full 8-lane chunk).
    let sum_scalar = {
        let sum_arr: [f32; 8] = std::mem::transmute(sum);
        sum_arr.iter().sum::<f32>() + data[i..].iter().sum::<f32>()
    };
    let mean = sum_scalar / len as f32;
    let mean_vec = _mm256_set1_ps(mean);
    // Calculate variance
    let mut var_sum = _mm256_setzero_ps();
    i = 0;
    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        let diff = _mm256_sub_ps(vals, mean_vec);
        let sq = _mm256_mul_ps(diff, diff);
        var_sum = _mm256_add_ps(var_sum, sq);
        i += 8;
    }
    let var_scalar = {
        let var_arr: [f32; 8] = std::mem::transmute(var_sum);
        var_arr.iter().sum::<f32>() + data[i..].iter().map(|x| (x - mean).powi(2)).sum::<f32>()
    };
    // Same epsilon as the scalar path to avoid dividing by zero.
    let std_dev = (var_scalar / len as f32).sqrt() + 1e-8;
    let std_vec = _mm256_set1_ps(std_dev);
    // Normalize
    i = 0;
    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        let centered = _mm256_sub_ps(vals, mean_vec);
        let normalized = _mm256_div_ps(centered, std_vec);
        _mm256_storeu_ps(data.as_mut_ptr().add(i), normalized);
        i += 8;
    }
    // Handle remaining elements
    for x in &mut data[i..] {
        *x = (*x - mean) / std_dev;
    }
}
/// Fast bilinear resize using SIMD - optimized for preprocessing
/// This is significantly faster than the image crate's resize for typical OCR sizes
///
/// Operates on single-channel (grayscale) images; `src` is row-major
/// `src_width * src_height`. Dispatches to the AVX2 kernel when available.
pub fn simd_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    if !simd_enabled() {
        return scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }
    let features = get_features();
    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // SAFETY: guarded by the runtime AVX2 feature check above.
            unsafe { avx2_resize_bilinear(src, src_width, src_height, dst_width, dst_height) }
        } else {
            scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
        }
    }
    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
    }
}
/// Scalar bilinear resize implementation.
///
/// Maps each destination pixel back into the source grid, clamps the four
/// neighbour coordinates to the image edges, and blends them with the
/// fractional weights. `src` is row-major `src_width * src_height`.
fn scalar_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;
    let mut dst = Vec::with_capacity(dst_width * dst_height);
    for y in 0..dst_height {
        let src_y = y as f32 * y_scale;
        let y0 = (src_y.floor() as usize).min(src_height - 1);
        // Clamp the lower neighbour row to the last row at the bottom edge.
        let y1 = (y0 + 1).min(src_height - 1);
        let y_frac = src_y - src_y.floor();
        let row0 = &src[y0 * src_width..y0 * src_width + src_width];
        let row1 = &src[y1 * src_width..y1 * src_width + src_width];
        for x in 0..dst_width {
            let src_x = x as f32 * x_scale;
            let x0 = (src_x.floor() as usize).min(src_width - 1);
            let x1 = (x0 + 1).min(src_width - 1);
            let x_frac = src_x - src_x.floor();
            // Blend horizontally in both neighbour rows, then vertically.
            let top = row0[x0] as f32 * (1.0 - x_frac) + row0[x1] as f32 * x_frac;
            let bottom = row1[x0] as f32 * (1.0 - x_frac) + row1[x1] as f32 * x_frac;
            let value = top * (1.0 - y_frac) + bottom * y_frac;
            dst.push(value.round() as u8);
        }
    }
    dst
}
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
// AVX2 dispatch target for bilinear resize.
//
// NOTE(review): the `_y_frac`/`_y_frac_inv` vectors are built but unused —
// the inner loop computes each of the 8 pixels in scalar code, so this is an
// 8-wide unrolled scalar kernel rather than a true SIMD implementation.
//
// # Safety
// Caller must ensure AVX2 is available at runtime and that `src` holds
// `src_width * src_height` bytes (unchecked indexing below).
unsafe fn avx2_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    use std::arch::x86_64::*;
    let mut dst = vec![0u8; dst_width * dst_height];
    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;
    // Process 8 output pixels at a time for x dimension
    for y in 0..dst_height {
        let src_y = y as f32 * y_scale;
        let y0 = (src_y.floor() as usize).min(src_height - 1);
        let y1 = (y0 + 1).min(src_height - 1);
        let _y_frac = _mm256_set1_ps(src_y - src_y.floor());
        let _y_frac_inv = _mm256_set1_ps(1.0 - (src_y - src_y.floor()));
        let mut x = 0;
        while x + 8 <= dst_width {
            // Calculate source x coordinates for 8 destination pixels
            let src_xs: [f32; 8] = [
                (x) as f32 * x_scale,
                (x + 1) as f32 * x_scale,
                (x + 2) as f32 * x_scale,
                (x + 3) as f32 * x_scale,
                (x + 4) as f32 * x_scale,
                (x + 5) as f32 * x_scale,
                (x + 6) as f32 * x_scale,
                (x + 7) as f32 * x_scale,
            ];
            let mut results = [0u8; 8];
            for i in 0..8 {
                let src_x = src_xs[i];
                let x0 = (src_x.floor() as usize).min(src_width - 1);
                let x1 = (x0 + 1).min(src_width - 1);
                let x_frac = src_x - src_x.floor();
                let p00 = *src.get_unchecked(y0 * src_width + x0) as f32;
                let p10 = *src.get_unchecked(y0 * src_width + x1) as f32;
                let p01 = *src.get_unchecked(y1 * src_width + x0) as f32;
                let p11 = *src.get_unchecked(y1 * src_width + x1) as f32;
                // Horizontal blends in the two neighbour rows, then vertical.
                let top = p00 * (1.0 - x_frac) + p10 * x_frac;
                let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
                let value =
                    top * (1.0 - (src_y - src_y.floor())) + bottom * (src_y - src_y.floor());
                results[i] = value.round() as u8;
            }
            for i in 0..8 {
                *dst.get_unchecked_mut(y * dst_width + x + i) = results[i];
            }
            x += 8;
        }
        // Handle remaining pixels
        while x < dst_width {
            let src_x = x as f32 * x_scale;
            let x0 = (src_x.floor() as usize).min(src_width - 1);
            let x1 = (x0 + 1).min(src_width - 1);
            let x_frac = src_x - src_x.floor();
            let p00 = *src.get_unchecked(y0 * src_width + x0) as f32;
            let p10 = *src.get_unchecked(y0 * src_width + x1) as f32;
            let p01 = *src.get_unchecked(y1 * src_width + x0) as f32;
            let p11 = *src.get_unchecked(y1 * src_width + x1) as f32;
            let top = p00 * (1.0 - x_frac) + p10 * x_frac;
            let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
            let value = top * (1.0 - (src_y - src_y.floor())) + bottom * (src_y - src_y.floor());
            *dst.get_unchecked_mut(y * dst_width + x) = value.round() as u8;
            x += 1;
        }
    }
    dst
}
/// Parallel SIMD resize for large images - splits work across threads
///
/// Each destination row is computed independently, so rows are distributed
/// across the rayon thread pool; small images skip the threading overhead.
#[cfg(feature = "rayon")]
pub fn parallel_simd_resize(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    use rayon::prelude::*;
    // For small images, use single-threaded SIMD
    if dst_height < 64 || dst_width * dst_height < 100_000 {
        return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }
    let mut dst = vec![0u8; dst_width * dst_height];
    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;
    // Process rows in parallel
    dst.par_chunks_mut(dst_width)
        .enumerate()
        .for_each(|(y, row)| {
            let src_y = y as f32 * y_scale;
            let y0 = (src_y.floor() as usize).min(src_height - 1);
            let y1 = (y0 + 1).min(src_height - 1);
            let y_frac = src_y - src_y.floor();
            for x in 0..dst_width {
                let src_x = x as f32 * x_scale;
                let x0 = (src_x.floor() as usize).min(src_width - 1);
                let x1 = (x0 + 1).min(src_width - 1);
                let x_frac = src_x - src_x.floor();
                let p00 = src[y0 * src_width + x0] as f32;
                let p10 = src[y0 * src_width + x1] as f32;
                let p01 = src[y1 * src_width + x0] as f32;
                let p11 = src[y1 * src_width + x1] as f32;
                // Same bilinear blend as the scalar implementation.
                let top = p00 * (1.0 - x_frac) + p10 * x_frac;
                let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
                let value = top * (1.0 - y_frac) + bottom * y_frac;
                row[x] = value.round() as u8;
            }
        });
    dst
}
/// Ultra-fast area average downscaling for preprocessing
/// Best for large images being scaled down significantly
///
/// Each output pixel is the mean of its covering source rectangle (box
/// filter), which avoids the aliasing bilinear sampling produces on strong
/// downscales. Upscaling falls back to bilinear.
pub fn fast_area_resize(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    // Only use area averaging for downscaling
    if dst_width >= src_width || dst_height >= src_height {
        return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }
    let mut dst = vec![0u8; dst_width * dst_height];
    let x_ratio = src_width as f32 / dst_width as f32;
    let y_ratio = src_height as f32 / dst_height as f32;
    for y in 0..dst_height {
        let y_start = (y as f32 * y_ratio) as usize;
        let y_end = (((y + 1) as f32 * y_ratio) as usize).min(src_height);
        for x in 0..dst_width {
            let x_start = (x as f32 * x_ratio) as usize;
            let x_end = (((x + 1) as f32 * x_ratio) as usize).min(src_width);
            // Calculate area average
            let mut sum: u32 = 0;
            let mut count: u32 = 0;
            for sy in y_start..y_end {
                for sx in x_start..x_end {
                    sum += src[sy * src_width + sx] as u32;
                    count += 1;
                }
            }
            // count can be 0 if float truncation yields an empty rectangle.
            dst[y * dst_width + x] = if count > 0 { (sum / count) as u8 } else { 0 };
        }
    }
    dst
}
#[cfg(test)]
mod tests {
    use super::*;
    // Spot-checks the BT.601 weighting: green contributes the most luma,
    // blue the least, and pure white maps to 255 exactly.
    #[test]
    fn test_grayscale_conversion() {
        let rgba = vec![
            255, 0, 0, 255, // Red
            0, 255, 0, 255, // Green
            0, 0, 255, 255, // Blue
            255, 255, 255, 255, // White
        ];
        let mut gray = vec![0u8; 4];
        simd_grayscale(&rgba, &mut gray);
        // Check approximately correct values
        assert!(gray[0] > 50 && gray[0] < 100); // Red
        assert!(gray[1] > 130 && gray[1] < 160); // Green
        assert!(gray[2] > 20 && gray[2] < 50); // Blue
        assert_eq!(gray[3], 255); // White
    }
    // Asserts strict-greater semantics: the pixel equal to the threshold
    // (100) must map to 0. Six bytes is below the 32-byte AVX2 width, so
    // this exercises the scalar remainder path on every CPU.
    #[test]
    fn test_threshold() {
        let gray = vec![0, 50, 100, 150, 200, 255];
        let mut out = vec![0u8; 6];
        simd_threshold(&gray, 100, &mut out);
        assert_eq!(out, vec![0, 0, 0, 255, 255, 255]);
    }
    #[test]
    fn test_normalize() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        simd_normalize(&mut data);
        // After normalization, mean should be ~0 and std dev ~1
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 1e-6);
    }
    // The SIMD dispatch path must agree byte-for-byte with the scalar path.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_simd_vs_scalar_grayscale() {
        let rgba: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mut gray_simd = vec![0u8; 256];
        let mut gray_scalar = vec![0u8; 256];
        simd_grayscale(&rgba, &mut gray_simd);
        scalar_grayscale(&rgba, &mut gray_scalar);
        assert_eq!(gray_simd, gray_scalar);
    }
}

View File

@@ -0,0 +1,298 @@
//! DOCX (Microsoft Word) formatter with Office Math ML support
//!
//! This is a stub implementation. Full DOCX generation requires:
//! - ZIP file creation for .docx format
//! - XML generation for document.xml, styles.xml, etc.
//! - Office Math ML for equations
//! - Image embedding support
//!
//! Consider using libraries like `docx-rs` for production implementation.
use super::{LineData, OcrResult};
use std::io::Write;
/// DOCX formatter (stub implementation)
///
/// Holds page geometry and style options; the methods on this type generate
/// WordprocessingML XML fragments, but binary .docx packaging (ZIP/OPC) is
/// not implemented.
#[allow(dead_code)]
pub struct DocxFormatter {
    // Whether documents should reference styles.xml styles; currently unused
    // (hence the dead_code allowance on this struct).
    include_styles: bool,
    // Page dimensions in twips (1/1440 inch).
    page_size: PageSize,
    // Page margins in twips.
    margins: Margins,
}
#[derive(Debug, Clone, Copy)]
pub struct PageSize {
    pub width: u32, // in twips (1/1440 inch)
    pub height: u32,
}
impl PageSize {
    /// US Letter: 8.5 x 11 inches expressed in twips.
    pub fn letter() -> Self {
        PageSize {
            width: 12240,
            height: 15840,
        }
    }
    /// ISO A4: 210 x 297 mm expressed in twips.
    pub fn a4() -> Self {
        PageSize {
            width: 11906,
            height: 16838,
        }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct Margins {
    pub top: u32,
    pub right: u32,
    pub bottom: u32,
    pub left: u32,
}
impl Margins {
    /// Word's "Normal" preset: one inch (1440 twips) on every side.
    pub fn normal() -> Self {
        let inch = 1440;
        Margins {
            top: inch,
            right: inch,
            bottom: inch,
            left: inch,
        }
    }
}
impl DocxFormatter {
    /// Create a formatter with US Letter pages and 1-inch margins.
    pub fn new() -> Self {
        Self {
            include_styles: true,
            page_size: PageSize::letter(),
            margins: Margins::normal(),
        }
    }
    /// Builder-style override for the page size.
    pub fn with_page_size(mut self, page_size: PageSize) -> Self {
        self.page_size = page_size;
        self
    }
    /// Builder-style override for the page margins.
    pub fn with_margins(mut self, margins: Margins) -> Self {
        self.margins = margins;
        self
    }
    /// Generate Office Math ML from LaTeX
    /// This is a simplified placeholder - real implementation needs proper conversion:
    /// it emits the escaped LaTeX source as literal math text rather than
    /// parsing it into OMML structure.
    pub fn latex_to_mathml(&self, latex: &str) -> String {
        // This is a very simplified stub
        // Real implementation would parse LaTeX and generate proper Office Math ML
        format!(
            r#"<m:oMathPara>
  <m:oMath>
    <m:r>
      <m:t>{}</m:t>
    </m:r>
  </m:oMath>
</m:oMathPara>"#,
            self.escape_xml(latex)
        )
    }
    /// Generate document.xml content: the document root with one element per
    /// OCR line inside the body.
    pub fn generate_document_xml(&self, lines: &[LineData]) -> String {
        let mut xml = String::from(
            r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<w:document xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main"
    xmlns:m="http://schemas.openxmlformats.org/officeDocument/2006/math">
    <w:body>
"#,
        );
        for line in lines {
            xml.push_str(&self.format_line(line));
        }
        xml.push_str("    </w:body>\n</w:document>");
        xml
    }
    /// Dispatch a line to the right fragment generator by its type tag;
    /// unknown types degrade to a plain paragraph.
    fn format_line(&self, line: &LineData) -> String {
        match line.line_type.as_str() {
            "text" => self.format_paragraph(&line.text),
            "math" | "equation" => {
                // Prefer the recognized LaTeX; fall back to the raw text.
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                self.format_math(latex)
            }
            "heading" => self.format_heading(&line.text, 1),
            _ => self.format_paragraph(&line.text),
        }
    }
    /// Plain paragraph: a single run holding the escaped text.
    fn format_paragraph(&self, text: &str) -> String {
        format!(
            r#"        <w:p>
            <w:r>
                <w:t>{}</w:t>
            </w:r>
        </w:p>
"#,
            self.escape_xml(text)
        )
    }
    /// Heading paragraph referencing the `Heading{level}` style from
    /// styles.xml.
    fn format_heading(&self, text: &str, level: u32) -> String {
        format!(
            r#"        <w:p>
            <w:pPr>
                <w:pStyle w:val="Heading{}"/>
            </w:pPr>
            <w:r>
                <w:t>{}</w:t>
            </w:r>
        </w:p>
"#,
            level,
            self.escape_xml(text)
        )
    }
    /// Math paragraph wrapping the generated Office Math markup.
    fn format_math(&self, latex: &str) -> String {
        let mathml = self.latex_to_mathml(latex);
        // Fix: `m:oMathPara` is a paragraph-level element in
        // WordprocessingML (ECMA-376) and must be a direct child of `w:p`.
        // The previous version nested it inside a `w:r` run, which Word
        // rejects as invalid document.xml.
        format!(
            r#"        <w:p>
{}
        </w:p>
"#,
            mathml
        )
    }
    /// Escape the five XML special characters. `&` is replaced first so the
    /// generated escapes are not themselves re-escaped.
    fn escape_xml(&self, text: &str) -> String {
        text.replace('&', "&amp;")
            .replace('<', "&lt;")
            .replace('>', "&gt;")
            .replace('"', "&quot;")
            .replace('\'', "&apos;")
    }
    /// Save DOCX to file (stub - needs ZIP implementation)
    ///
    /// # Errors
    /// Always returns `Err`: OPC/ZIP packaging is not implemented here.
    pub fn save_to_file<W: Write>(
        &self,
        _writer: &mut W,
        _result: &OcrResult,
    ) -> Result<(), String> {
        Err("DOCX binary format generation not implemented. Use docx-rs library for full implementation.".to_string())
    }
    /// Generate styles.xml content defining the `Normal` and `Heading1`
    /// paragraph styles referenced by the generated document.
    pub fn generate_styles_xml(&self) -> String {
        r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
<w:styles xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main">
    <w:style w:type="paragraph" w:styleId="Normal">
        <w:name w:val="Normal"/>
        <w:qFormat/>
    </w:style>
    <w:style w:type="paragraph" w:styleId="Heading1">
        <w:name w:val="Heading 1"/>
        <w:basedOn w:val="Normal"/>
        <w:qFormat/>
        <w:pPr>
            <w:keepNext/>
            <w:keepLines/>
        </w:pPr>
        <w:rPr>
            <w:b/>
            <w:sz w:val="32"/>
        </w:rPr>
    </w:style>
</w:styles>"#
            .to_string()
    }
}
impl Default for DocxFormatter {
    /// Equivalent to [`DocxFormatter::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;
    #[test]
    fn test_page_sizes() {
        let letter = PageSize::letter();
        assert_eq!(letter.width, 12240);
        // A4 (210mm) is narrower than US Letter (8.5in).
        let a4 = PageSize::a4();
        assert!(a4.width < letter.width);
    }
    // All five XML special characters must come back entity-encoded.
    #[test]
    fn test_escape_xml() {
        let formatter = DocxFormatter::new();
        let result = formatter.escape_xml("Test <tag> & \"quote\"");
        assert!(result.contains("&lt;"));
        assert!(result.contains("&gt;"));
        assert!(result.contains("&amp;"));
        assert!(result.contains("&quot;"));
    }
    #[test]
    fn test_format_paragraph() {
        let formatter = DocxFormatter::new();
        let result = formatter.format_paragraph("Hello World");
        assert!(result.contains("<w:p>"));
        assert!(result.contains("<w:t>Hello World</w:t>"));
    }
    #[test]
    fn test_format_heading() {
        let formatter = DocxFormatter::new();
        let result = formatter.format_heading("Chapter 1", 1);
        assert!(result.contains("Heading1"));
        assert!(result.contains("Chapter 1"));
    }
    // Only structural containment is checked — the stub passes the LaTeX
    // source through as literal math text.
    #[test]
    fn test_latex_to_mathml() {
        let formatter = DocxFormatter::new();
        let result = formatter.latex_to_mathml("E = mc^2");
        assert!(result.contains("<m:oMath>"));
        assert!(result.contains("mc^2"));
    }
    #[test]
    fn test_generate_document_xml() {
        let formatter = DocxFormatter::new();
        let lines = vec![LineData {
            line_type: "text".to_string(),
            text: "Hello".to_string(),
            latex: None,
            bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
            confidence: 0.95,
            words: None,
        }];
        let xml = formatter.generate_document_xml(&lines);
        assert!(xml.contains("<?xml"));
        assert!(xml.contains("<w:document"));
        assert!(xml.contains("Hello"));
    }
    #[test]
    fn test_generate_styles_xml() {
        let formatter = DocxFormatter::new();
        let xml = formatter.generate_styles_xml();
        assert!(xml.contains("<w:styles"));
        assert!(xml.contains("Normal"));
        assert!(xml.contains("Heading 1"));
    }
}

View File

@@ -0,0 +1,412 @@
//! Multi-format output formatter with batch processing and streaming support
use super::*;
use crate::output::{html, latex, mmd, smiles};
use std::io::Write;
/// Configuration for output formatting
///
/// Consumed by [`OutputFormatter`]; `formats` selects which renditions
/// `format_result` produces.
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Target output formats
    pub formats: Vec<OutputFormat>,
    /// Enable pretty printing (where applicable)
    pub pretty: bool,
    /// Include confidence scores in output
    pub include_confidence: bool,
    /// Include bounding box data
    pub include_bbox: bool,
    /// Math delimiter style for LaTeX/MMD
    pub math_delimiters: MathDelimiters,
    /// HTML rendering engine
    pub html_engine: HtmlEngine,
    /// Enable streaming for large documents
    pub streaming: bool,
}
impl Default for FormatterConfig {
    /// Conservative defaults: plain-text output only, pretty printing on,
    /// no confidence/bbox metadata, TeX `$`/`$$` delimiters, MathJax HTML
    /// engine, streaming off.
    fn default() -> Self {
        Self {
            formats: vec![OutputFormat::Text],
            pretty: true,
            include_confidence: false,
            include_bbox: false,
            math_delimiters: MathDelimiters::default(),
            html_engine: HtmlEngine::MathJax,
            streaming: false,
        }
    }
}
/// Math delimiter configuration
///
/// The start/end marker pairs wrapped around inline and display math when
/// rendering LaTeX or MMD output.
#[derive(Debug, Clone)]
pub struct MathDelimiters {
    pub inline_start: String,
    pub inline_end: String,
    pub display_start: String,
    pub display_end: String,
}
impl Default for MathDelimiters {
    /// TeX-style defaults: `$ … $` for inline math, `$$ … $$` for display.
    fn default() -> Self {
        MathDelimiters {
            inline_start: String::from("$"),
            inline_end: String::from("$"),
            display_start: String::from("$$"),
            display_end: String::from("$$"),
        }
    }
}
/// HTML rendering engine options
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HtmlEngine {
    /// Target the MathJax math renderer.
    MathJax,
    /// Target the KaTeX math renderer.
    KaTeX,
    /// No math-rendering engine; emit raw markup.
    Raw,
}
/// Main output formatter
///
/// Converts OCR results into the formats requested in its
/// [`FormatterConfig`], with single, batch, and streaming entry points.
pub struct OutputFormatter {
    config: FormatterConfig,
}
impl OutputFormatter {
    /// Create a new formatter with default configuration
    /// (plain-text output only; see [`FormatterConfig::default`]).
    pub fn new() -> Self {
        Self {
            config: FormatterConfig::default(),
        }
    }
    /// Create a formatter with custom configuration
    pub fn with_config(config: FormatterConfig) -> Self {
        Self { config }
    }
    /// Format a single OCR result into every format requested in the config.
    ///
    /// # Errors
    /// Returns the first formatting error encountered; formats after the
    /// failing one are not attempted.
    pub fn format_result(&self, result: &OcrResult) -> Result<FormatsData, String> {
        let mut formats = FormatsData::default();
        for format in &self.config.formats {
            let output = self.format_single(result, *format)?;
            self.set_format_output(&mut formats, *format, output);
        }
        Ok(formats)
    }
/// Format multiple results in batch
pub fn format_batch(&self, results: &[OcrResult]) -> Result<Vec<FormatsData>, String> {
results
.iter()
.map(|result| self.format_result(result))
.collect()
}
    /// Stream format results to a writer
    ///
    /// Writes each result in the single given `format`, separated by a
    /// `\n\n---\n\n` divider (no trailing divider after the last result).
    ///
    /// # Errors
    /// Returns a message on the first formatting or I/O failure.
    pub fn format_stream<W: Write>(
        &self,
        results: &[OcrResult],
        writer: &mut W,
        format: OutputFormat,
    ) -> Result<(), String> {
        for (i, result) in results.iter().enumerate() {
            let output = self.format_single(result, format)?;
            writer
                .write_all(output.as_bytes())
                .map_err(|e| format!("Write error: {}", e))?;
            // Add separator between results
            if i < results.len() - 1 {
                writer
                    .write_all(b"\n\n---\n\n")
                    .map_err(|e| format!("Write error: {}", e))?;
            }
        }
        Ok(())
    }
    /// Format a single result to a specific format
    ///
    /// Pure dispatch table — each arm delegates to the matching `format_*`
    /// helper on this type.
    fn format_single(&self, result: &OcrResult, format: OutputFormat) -> Result<String, String> {
        match format {
            OutputFormat::Text => self.format_text(result),
            OutputFormat::LaTeX => self.format_latex(result, false),
            OutputFormat::LaTeXStyled => self.format_latex(result, true),
            OutputFormat::Mmd => self.format_mmd(result),
            OutputFormat::Html => self.format_html(result),
            OutputFormat::Smiles => self.format_smiles(result),
            OutputFormat::Docx => self.format_docx(result),
            OutputFormat::MathML => self.format_mathml(result),
            OutputFormat::AsciiMath => self.format_asciimath(result),
        }
    }
fn format_text(&self, result: &OcrResult) -> Result<String, String> {
if let Some(text) = &result.formats.text {
return Ok(text.clone());
}
// Fallback: extract text from line data
if let Some(line_data) = &result.line_data {
let text = line_data
.iter()
.map(|line| line.text.as_str())
.collect::<Vec<_>>()
.join("\n");
return Ok(text);
}
Err("No text content available".to_string())
}
fn format_latex(&self, result: &OcrResult, styled: bool) -> Result<String, String> {
let latex_content = if styled {
result
.formats
.latex_styled
.as_ref()
.or(result.formats.latex_normal.as_ref())
} else {
result.formats.latex_normal.as_ref()
};
if let Some(latex) = latex_content {
if styled {
// Wrap in document with packages
Ok(latex::LaTeXFormatter::new()
.with_packages(vec![
"amsmath".to_string(),
"amssymb".to_string(),
"graphicx".to_string(),
])
.format_document(latex))
} else {
Ok(latex.clone())
}
} else {
Err("No LaTeX content available".to_string())
}
}
/// MMD (Scipix Markdown) rendering: prefer the precomputed field,
/// otherwise synthesize it from the structured line data.
fn format_mmd(&self, result: &OcrResult) -> Result<String, String> {
    if let Some(mmd) = &result.formats.mmd {
        return Ok(mmd.clone());
    }
    // Generate MMD from line data
    if let Some(line_data) = &result.line_data {
        let formatter = mmd::MmdFormatter::with_delimiters(self.config.math_delimiters.clone());
        return Ok(formatter.format(line_data));
    }
    Err("No MMD content available".to_string())
}
/// HTML rendering: use the precomputed HTML when present, otherwise render
/// the plain-text content through the configured HTML engine. Fails only
/// when `format_text` also has nothing to work with.
fn format_html(&self, result: &OcrResult) -> Result<String, String> {
    if let Some(html) = &result.formats.html {
        return Ok(html.clone());
    }
    // Generate HTML with math rendering
    let content = self.format_text(result)?;
    let formatter = html::HtmlFormatter::new()
        .with_engine(self.config.html_engine)
        .with_styling(self.config.pretty);
    Ok(formatter.format(&content, result.line_data.as_deref()))
}
/// SMILES rendering: precomputed field first, then the generator.
fn format_smiles(&self, result: &OcrResult) -> Result<String, String> {
    if let Some(smiles) = &result.formats.smiles {
        return Ok(smiles.clone());
    }
    // Generate SMILES if we have chemical structure data
    let generator = smiles::SmilesGenerator::new();
    generator.generate_from_result(result)
}
/// DOCX is a binary container; this text-based path always errors.
fn format_docx(&self, _result: &OcrResult) -> Result<String, String> {
    // DOCX requires binary format, return placeholder
    Err("DOCX format requires binary output - use save_docx() instead".to_string())
}
/// MathML rendering: only the precomputed field is supported today.
fn format_mathml(&self, result: &OcrResult) -> Result<String, String> {
    if let Some(mathml) = &result.formats.mathml {
        return Ok(mathml.clone());
    }
    Err("MathML generation not yet implemented".to_string())
}
/// AsciiMath rendering: only the precomputed field is supported today.
fn format_asciimath(&self, result: &OcrResult) -> Result<String, String> {
    if let Some(asciimath) = &result.formats.asciimath {
        return Ok(asciimath.clone());
    }
    Err("AsciiMath conversion not yet implemented".to_string())
}
/// Store `output` in the `FormatsData` slot corresponding to `format`.
fn set_format_output(&self, formats: &mut FormatsData, format: OutputFormat, output: String) {
    match format {
        OutputFormat::Text => formats.text = Some(output),
        OutputFormat::LaTeX => formats.latex_normal = Some(output),
        OutputFormat::LaTeXStyled => formats.latex_styled = Some(output),
        OutputFormat::Mmd => formats.mmd = Some(output),
        OutputFormat::Html => formats.html = Some(output),
        OutputFormat::Smiles => formats.smiles = Some(output),
        OutputFormat::MathML => formats.mathml = Some(output),
        OutputFormat::AsciiMath => formats.asciimath = Some(output),
        OutputFormat::Docx => {} // Binary format, handled separately
    }
}
}
/// `Default` delegates to `new()`, i.e. the default `FormatterConfig`.
impl Default for OutputFormatter {
    fn default() -> Self {
        Self::new()
    }
}
/// Builder for OutputFormatter configuration
pub struct FormatterBuilder {
    // Configuration accumulated by the setters and handed over in build().
    config: FormatterConfig,
}
impl FormatterBuilder {
    /// Start from the default configuration.
    pub fn new() -> Self {
        Self {
            config: FormatterConfig::default(),
        }
    }
    /// Replace the whole list of requested output formats.
    pub fn formats(mut self, formats: Vec<OutputFormat>) -> Self {
        self.config.formats = formats;
        self
    }
    /// Append one output format to the request list (duplicates allowed).
    pub fn add_format(mut self, format: OutputFormat) -> Self {
        self.config.formats.push(format);
        self
    }
    /// Toggle pretty (styled) output.
    pub fn pretty(mut self, pretty: bool) -> Self {
        self.config.pretty = pretty;
        self
    }
    /// Toggle inclusion of confidence scores in the output.
    pub fn include_confidence(mut self, include: bool) -> Self {
        self.config.include_confidence = include;
        self
    }
    /// Toggle inclusion of bounding boxes in the output.
    pub fn include_bbox(mut self, include: bool) -> Self {
        self.config.include_bbox = include;
        self
    }
    /// Set the math delimiter configuration used by MMD output.
    pub fn math_delimiters(mut self, delimiters: MathDelimiters) -> Self {
        self.config.math_delimiters = delimiters;
        self
    }
    /// Set the client-side math engine used by HTML output.
    pub fn html_engine(mut self, engine: HtmlEngine) -> Self {
        self.config.html_engine = engine;
        self
    }
    /// Toggle streaming mode.
    pub fn streaming(mut self, streaming: bool) -> Self {
        self.config.streaming = streaming;
        self
    }
    /// Consume the builder and produce the configured formatter.
    pub fn build(self) -> OutputFormatter {
        OutputFormatter::with_config(self.config)
    }
}
/// `Default` builder starts from the default configuration.
impl Default for FormatterBuilder {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests for the formatter facade and its builder.
#[cfg(test)]
mod tests {
    use super::*;
    // Minimal OcrResult with text and plain-LaTeX formats populated.
    fn create_test_result() -> OcrResult {
        OcrResult {
            request_id: "test_123".to_string(),
            version: "3.0".to_string(),
            image_width: 800,
            image_height: 600,
            is_printed: true,
            is_handwritten: false,
            auto_rotate_confidence: 0.95,
            auto_rotate_degrees: 0,
            confidence: 0.98,
            confidence_rate: 0.97,
            formats: FormatsData {
                text: Some("E = mc^2".to_string()),
                latex_normal: Some(r"E = mc^2".to_string()),
                ..Default::default()
            },
            line_data: None,
            error: None,
            metadata: HashMap::new(),
        }
    }
    #[test]
    fn test_format_text() {
        let formatter = OutputFormatter::new();
        let result = create_test_result();
        let output = formatter
            .format_single(&result, OutputFormat::Text)
            .unwrap();
        assert_eq!(output, "E = mc^2");
    }
    #[test]
    fn test_format_latex() {
        let formatter = OutputFormatter::new();
        let result = create_test_result();
        let output = formatter
            .format_single(&result, OutputFormat::LaTeX)
            .unwrap();
        assert!(output.contains("mc^2"));
    }
    #[test]
    fn test_builder() {
        let formatter = FormatterBuilder::new()
            .add_format(OutputFormat::Text)
            .add_format(OutputFormat::LaTeX)
            .pretty(true)
            .include_confidence(true)
            .build();
        assert_eq!(formatter.config.formats.len(), 2);
        assert!(formatter.config.pretty);
        assert!(formatter.config.include_confidence);
    }
    #[test]
    fn test_batch_format() {
        let formatter = OutputFormatter::new();
        let results = vec![create_test_result(), create_test_result()];
        let outputs = formatter.format_batch(&results).unwrap();
        assert_eq!(outputs.len(), 2);
    }
}

View File

@@ -0,0 +1,396 @@
//! HTML output formatter with math rendering support
use super::{HtmlEngine, LineData};
/// HTML formatter with math rendering
pub struct HtmlFormatter {
    // Client-side math renderer whose scripts are emitted (MathJax/KaTeX/none).
    engine: HtmlEngine,
    // Emit an inline <style> block and a theme class on <body>.
    css_styling: bool,
    // Emit screen-reader-only descriptions for equations.
    accessibility: bool,
    // Emit the responsive viewport meta tag.
    responsive: bool,
    // Color scheme used by the generated CSS.
    theme: HtmlTheme,
}
/// Page color scheme; `Auto` defers to the client's
/// `prefers-color-scheme` media query.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HtmlTheme {
    Light,
    Dark,
    Auto,
}
impl HtmlFormatter {
/// Default formatter: MathJax engine, styling + accessibility + responsive
/// layout enabled, light theme.
pub fn new() -> Self {
    Self {
        engine: HtmlEngine::MathJax,
        css_styling: true,
        accessibility: true,
        responsive: true,
        theme: HtmlTheme::Light,
    }
}
/// Select the client-side math rendering engine.
pub fn with_engine(mut self, engine: HtmlEngine) -> Self {
    self.engine = engine;
    self
}
/// Toggle inline CSS generation (also controls the theme class on <body>).
pub fn with_styling(mut self, styling: bool) -> Self {
    self.css_styling = styling;
    self
}
/// Toggle screen-reader helpers (sr-only equation descriptions).
pub fn accessibility(mut self, enabled: bool) -> Self {
    self.accessibility = enabled;
    self
}
/// Toggle the responsive viewport meta tag.
pub fn responsive(mut self, enabled: bool) -> Self {
    self.responsive = enabled;
    self
}
/// Select the color theme used by the generated CSS.
pub fn theme(mut self, theme: HtmlTheme) -> Self {
    self.theme = theme;
    self
}
/// Assemble a complete HTML document around `content`.
///
/// When structured `lines` are available they drive the body markup;
/// otherwise the plain `content` string is rendered as a paragraph.
pub fn format(&self, content: &str, lines: Option<&[LineData]>) -> String {
    // Theme class on <body>, only when CSS styling is enabled.
    let body_attr = if self.css_styling {
        format!(r#" class="theme-{:?}""#, self.theme).to_lowercase()
    } else {
        String::new()
    };
    let body = match lines {
        Some(line_data) => self.format_lines(line_data),
        None => self.format_text(content),
    };
    format!(
        "{header}<body{attr}>\n<div class=\"content\">\n{body}</div>\n</body>\n</html>",
        header = self.html_header(),
        attr = body_attr,
        body = body
    )
}
/// Generate HTML header with scripts and styles
///
/// Produces everything from `<!DOCTYPE html>` through `</head>`: charset
/// and (optionally) viewport metas, the script tags for the configured
/// math engine, and an inline `<style>` block when styling is enabled.
fn html_header(&self) -> String {
    let mut header = String::from("<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n");
    header.push_str(r#" <meta charset="UTF-8">"#);
    header.push_str("\n");
    if self.responsive {
        header.push_str(
            r#" <meta name="viewport" content="width=device-width, initial-scale=1.0">"#,
        );
        header.push_str("\n");
    }
    header.push_str(" <title>Mathematical Content</title>\n");
    // Math rendering scripts
    match self.engine {
        HtmlEngine::MathJax => {
            // Polyfill + MathJax 3, configured for $/$$ and \( \)/\[ \] delimiters.
            header.push_str(r#" <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>"#);
            header.push_str("\n");
            header.push_str(r#" <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>"#);
            header.push_str("\n");
            header.push_str(" <script>\n");
            header.push_str(" MathJax = {\n");
            header.push_str(" tex: {\n");
            header.push_str(r#" inlineMath: [['$', '$'], ['\\(', '\\)']],"#);
            header.push_str("\n");
            header.push_str(r#" displayMath: [['$$', '$$'], ['\\[', '\\]']]"#);
            header.push_str("\n");
            header.push_str(" }\n");
            header.push_str(" };\n");
            header.push_str(" </script>\n");
        }
        HtmlEngine::KaTeX => {
            // KaTeX CSS + JS with auto-render triggered once the script loads.
            header.push_str(r#" <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.css">"#);
            header.push_str("\n");
            header.push_str(r#" <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.js"></script>"#);
            header.push_str("\n");
            header.push_str(r#" <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/contrib/auto-render.min.js" onload="renderMathInElement(document.body);"></script>"#);
            header.push_str("\n");
        }
        HtmlEngine::Raw => {
            // No math rendering
        }
    }
    // CSS styling
    if self.css_styling {
        header.push_str(" <style>\n");
        header.push_str(&self.generate_css());
        header.push_str(" </style>\n");
    }
    header.push_str("</head>\n");
    header
}
/// Generate CSS styles
///
/// Base typography, theme colors (honoring `prefers-color-scheme` for
/// `Auto`), content/math layout helpers, table styling and — when
/// accessibility is on — the `.sr-only` visually-hidden helper class.
fn generate_css(&self) -> String {
    let mut css = String::new();
    css.push_str(" body {\n");
    css.push_str(" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n");
    css.push_str(" line-height: 1.6;\n");
    css.push_str(" max-width: 800px;\n");
    css.push_str(" margin: 0 auto;\n");
    css.push_str(" padding: 20px;\n");
    css.push_str(" }\n");
    // Theme colors
    match self.theme {
        HtmlTheme::Light => {
            css.push_str(" body.theme-light {\n");
            css.push_str(" background-color: #ffffff;\n");
            css.push_str(" color: #333333;\n");
            css.push_str(" }\n");
        }
        HtmlTheme::Dark => {
            css.push_str(" body.theme-dark {\n");
            css.push_str(" background-color: #1e1e1e;\n");
            css.push_str(" color: #d4d4d4;\n");
            css.push_str(" }\n");
        }
        HtmlTheme::Auto => {
            // Follow the client's OS-level color-scheme preference.
            css.push_str(" @media (prefers-color-scheme: dark) {\n");
            css.push_str(" body { background-color: #1e1e1e; color: #d4d4d4; }\n");
            css.push_str(" }\n");
        }
    }
    css.push_str(" .content { padding: 20px; }\n");
    css.push_str(" .math-display { text-align: center; margin: 20px 0; }\n");
    css.push_str(" .math-inline { display: inline; }\n");
    css.push_str(" .equation-block { margin: 15px 0; padding: 10px; background: #f5f5f5; border-radius: 4px; }\n");
    css.push_str(" table { border-collapse: collapse; width: 100%; margin: 20px 0; }\n");
    css.push_str(
        " th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }\n",
    );
    css.push_str(" th { background-color: #f2f2f2; }\n");
    if self.accessibility {
        css.push_str(" .sr-only { position: absolute; width: 1px; height: 1px; padding: 0; margin: -1px; overflow: hidden; clip: rect(0,0,0,0); border: 0; }\n");
    }
    css
}
/// Format plain text to HTML.
///
/// Escapes HTML metacharacters, then wraps each `$$ … $$` display-math
/// span in a centered container for the client-side math engine.
///
/// # Bug fixed
/// The previous implementation called `replace("$$", …)` twice; the second
/// call rewrote the delimiters the first had just inserted, so every `$$`
/// became `<div class="math-display">$$</div>` and the math content itself
/// was never wrapped. We now split on `$$` and treat alternating segments
/// as prose / display math.
fn format_text(&self, text: &str) -> String {
    let escaped = self.escape_html(text);
    let mut html = String::with_capacity(escaped.len());
    for (i, segment) in escaped.split("$$").enumerate() {
        if i % 2 == 1 {
            // Odd segments sit between a pair of $$ delimiters: display math.
            // An unbalanced trailing delimiter still gets a synthesized close.
            html.push_str("<div class=\"math-display\">$$");
            html.push_str(segment);
            html.push_str("$$</div>");
        } else {
            html.push_str(segment);
        }
    }
    // Inline math $...$ is left to the client-side renderer; proper
    // parsing would be needed to wrap it here.
    format!("<p>{}</p>", html)
}
/// Format line data to HTML
///
/// Maps each recognized line to markup according to its `line_type`;
/// unknown types fall back to a plain, escaped paragraph.
fn format_lines(&self, lines: &[LineData]) -> String {
    let mut html = String::new();
    for line in lines {
        match line.line_type.as_str() {
            "text" => {
                html.push_str("<p>");
                html.push_str(&self.escape_html(&line.text));
                html.push_str("</p>\n");
            }
            "math" | "equation" => {
                // Prefer the LaTeX payload; fall back to the raw text.
                // NOTE(review): the LaTeX is emitted unescaped so the math
                // engine can read it — confirm these inputs are trusted.
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                html.push_str(r#"<div class="math-display">"#);
                if self.accessibility {
                    // Hidden description so screen readers announce the equation.
                    html.push_str(&format!(
                        r#"<span class="sr-only">Equation: {}</span>"#,
                        self.escape_html(&line.text)
                    ));
                }
                html.push_str(&format!("$${}$$", latex));
                html.push_str("</div>\n");
            }
            "inline_math" => {
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                html.push_str(&format!(r#"<span class="math-inline">${}$</span>"#, latex));
            }
            "heading" => {
                html.push_str(&format!("<h2>{}</h2>\n", self.escape_html(&line.text)));
            }
            "table" => {
                html.push_str(&self.format_table(&line.text));
            }
            "image" => {
                // NOTE(review): line.text is assumed to hold the image
                // URL/path for this line type — confirm against the producer.
                html.push_str(&format!(
                    r#"<img src="{}" alt="Image" loading="lazy">"#,
                    self.escape_html(&line.text)
                ));
                html.push_str("\n");
            }
            _ => {
                html.push_str("<p>");
                html.push_str(&self.escape_html(&line.text));
                html.push_str("</p>\n");
            }
        }
    }
    html
}
/// Render a pipe-delimited table as an HTML `<table>`.
///
/// Cells are split on `|`, trimmed, and empty fragments dropped; the first
/// row is emitted with `<th>` cells, every later row with `<td>`.
fn format_table(&self, table: &str) -> String {
    let mut html = String::from("<table>\n");
    for (row_idx, row) in table.lines().enumerate() {
        // Header cells for the first row, data cells afterwards.
        let tag = if row_idx == 0 { "th" } else { "td" };
        html.push_str(" <tr>\n");
        for cell in row.split('|').map(str::trim).filter(|c| !c.is_empty()) {
            html.push_str(&format!(
                " <{}>{}</{}>\n",
                tag,
                self.escape_html(cell),
                tag
            ));
        }
        html.push_str(" </tr>\n");
    }
    html.push_str("</table>\n");
    html
}
/// Escape the five HTML metacharacters in a single pass over the input.
/// Produces the same output as chained `replace` calls ('&' handled first),
/// since each source character maps independently to its entity.
fn escape_html(&self, text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }
    escaped
}
}
/// `Default` matches `new()`: MathJax, light theme, all toggles enabled.
impl Default for HtmlFormatter {
    fn default() -> Self {
        Self::new()
    }
}
// Unit tests: header generation per engine, escaping, line rendering,
// theming and accessibility markup.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;
    #[test]
    fn test_html_header() {
        let formatter = HtmlFormatter::new().with_engine(HtmlEngine::MathJax);
        let header = formatter.html_header();
        assert!(header.contains("<!DOCTYPE html>"));
        assert!(header.contains("MathJax"));
    }
    #[test]
    fn test_katex_header() {
        let formatter = HtmlFormatter::new().with_engine(HtmlEngine::KaTeX);
        let header = formatter.html_header();
        assert!(header.contains("katex"));
    }
    #[test]
    fn test_escape_html() {
        let formatter = HtmlFormatter::new();
        let result = formatter.escape_html("<script>alert('test')</script>");
        assert!(result.contains("&lt;"));
        assert!(result.contains("&gt;"));
        assert!(!result.contains("<script>"));
    }
    #[test]
    fn test_format_lines() {
        let formatter = HtmlFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "Introduction".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "equation".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];
        let result = formatter.format_lines(&lines);
        assert!(result.contains("<p>Introduction</p>"));
        assert!(result.contains("math-display"));
        assert!(result.contains("$$"));
    }
    #[test]
    fn test_dark_theme() {
        let formatter = HtmlFormatter::new().theme(HtmlTheme::Dark);
        let css = formatter.generate_css();
        assert!(css.contains("theme-dark"));
        assert!(css.contains("#1e1e1e"));
    }
    #[test]
    fn test_accessibility() {
        let formatter = HtmlFormatter::new().accessibility(true);
        let lines = vec![LineData {
            line_type: "equation".to_string(),
            text: "x squared".to_string(),
            latex: Some("x^2".to_string()),
            bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
            confidence: 0.98,
            words: None,
        }];
        let result = formatter.format_lines(&lines);
        assert!(result.contains("sr-only"));
        assert!(result.contains("Equation:"));
    }
}

View File

@@ -0,0 +1,354 @@
//! JSON API response formatter matching Scipix API specification
use super::{FormatsData, LineData, OcrResult};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
/// Complete API response matching Scipix format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiResponse {
    /// Request identifier
    pub request_id: String,
    /// API version
    pub version: String,
    /// Image information
    pub image_width: u32,
    pub image_height: u32,
    /// Detection metadata
    pub is_printed: bool,
    pub is_handwritten: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_rotate_confidence: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_rotate_degrees: Option<i32>,
    /// Confidence metrics
    pub confidence: f32,
    pub confidence_rate: f32,
    /// Available output formats
    // `flatten` lifts the FormatsData fields to the top level of the JSON.
    #[serde(flatten)]
    pub formats: FormatsData,
    /// Detailed line data
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_data: Option<Vec<LineData>>,
    /// Error information
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_info: Option<ErrorInfo>,
    /// Processing metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
/// Error information structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorInfo {
    // Machine-readable error code (e.g. "invalid_image").
    pub code: String,
    // Human-readable description.
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}
impl ApiResponse {
    /// Create response from OCR result
    ///
    /// Field-for-field copy; empty `metadata` collapses to `None` so it is
    /// omitted from the serialized JSON.
    pub fn from_ocr_result(result: OcrResult) -> Self {
        Self {
            request_id: result.request_id,
            version: result.version,
            image_width: result.image_width,
            image_height: result.image_height,
            is_printed: result.is_printed,
            is_handwritten: result.is_handwritten,
            auto_rotate_confidence: Some(result.auto_rotate_confidence),
            auto_rotate_degrees: Some(result.auto_rotate_degrees),
            confidence: result.confidence,
            confidence_rate: result.confidence_rate,
            formats: result.formats,
            line_data: result.line_data,
            error: result.error,
            error_info: None,
            metadata: if result.metadata.is_empty() {
                None
            } else {
                Some(result.metadata)
            },
        }
    }
    /// Create error response
    ///
    /// Zeroed payload with `error` and structured `error_info` populated.
    pub fn error(request_id: String, code: &str, message: &str) -> Self {
        Self {
            request_id,
            version: "3.0".to_string(),
            image_width: 0,
            image_height: 0,
            is_printed: false,
            is_handwritten: false,
            auto_rotate_confidence: None,
            auto_rotate_degrees: None,
            confidence: 0.0,
            confidence_rate: 0.0,
            formats: FormatsData::default(),
            line_data: None,
            error: Some(message.to_string()),
            error_info: Some(ErrorInfo {
                code: code.to_string(),
                message: message.to_string(),
                details: None,
            }),
            metadata: None,
        }
    }
    /// Convert to JSON string
    pub fn to_json(&self) -> Result<String, String> {
        serde_json::to_string(self).map_err(|e| format!("JSON serialization error: {}", e))
    }
    /// Convert to pretty JSON string
    pub fn to_json_pretty(&self) -> Result<String, String> {
        serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization error: {}", e))
    }
    /// Parse from JSON string
    pub fn from_json(json: &str) -> Result<Self, String> {
        serde_json::from_str(json).map_err(|e| format!("JSON parsing error: {}", e))
    }
}
/// Batch API response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchApiResponse {
    pub batch_id: String,
    // Total number of results in the batch.
    pub total: usize,
    // Results whose `error` field is None.
    pub completed: usize,
    pub results: Vec<ApiResponse>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub errors: Option<Vec<BatchError>>,
}
/// Per-result error, tagged with its index in the batch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchError {
    pub index: usize,
    pub error: ErrorInfo,
}
impl BatchApiResponse {
    /// Build a batch response from individual results.
    ///
    /// `completed` counts results without an `error`; the error list is
    /// collected from `error_info` fields and omitted entirely when empty
    /// (so it serializes as absent rather than `[]`).
    pub fn new(batch_id: String, results: Vec<ApiResponse>) -> Self {
        let total = results.len();
        let mut completed = 0;
        let mut errors = Vec::new();
        // Single pass: tally successes and gather per-index error details.
        for (index, response) in results.iter().enumerate() {
            if response.error.is_none() {
                completed += 1;
            }
            if let Some(info) = &response.error_info {
                errors.push(BatchError {
                    index,
                    error: info.clone(),
                });
            }
        }
        let errors = if errors.is_empty() { None } else { Some(errors) };
        Self {
            batch_id,
            total,
            completed,
            results,
            errors,
        }
    }
    /// Serialize to compact JSON.
    pub fn to_json(&self) -> Result<String, String> {
        match serde_json::to_string(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(format!("JSON serialization error: {}", e)),
        }
    }
    /// Serialize to indented, human-readable JSON.
    pub fn to_json_pretty(&self) -> Result<String, String> {
        match serde_json::to_string_pretty(self) {
            Ok(json) => Ok(json),
            Err(e) => Err(format!("JSON serialization error: {}", e)),
        }
    }
}
/// API request format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiRequest {
    /// Image source (URL or base64)
    pub src: String,
    /// Requested output formats
    #[serde(skip_serializing_if = "Option::is_none")]
    pub formats: Option<Vec<String>>,
    /// OCR options
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ocr: Option<OcrOptions>,
    /// Additional metadata
    // `flatten` captures any extra top-level JSON keys into this map.
    #[serde(flatten)]
    pub metadata: HashMap<String, Value>,
}
/// Optional OCR tuning knobs; absent fields are omitted from the JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub math_inline_delimiters: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub math_display_delimiters: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rm_spaces: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rm_fonts: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub numbers_default_to_math: Option<bool>,
}
// Unit tests: response construction, JSON round-trips, error and batch paths.
#[cfg(test)]
mod tests {
    use super::*;
    // Minimal successful OcrResult fixture.
    fn create_test_result() -> OcrResult {
        OcrResult {
            request_id: "test_123".to_string(),
            version: "3.0".to_string(),
            image_width: 800,
            image_height: 600,
            is_printed: true,
            is_handwritten: false,
            auto_rotate_confidence: 0.95,
            auto_rotate_degrees: 0,
            confidence: 0.98,
            confidence_rate: 0.97,
            formats: FormatsData {
                text: Some("E = mc^2".to_string()),
                latex_normal: Some(r"E = mc^2".to_string()),
                ..Default::default()
            },
            line_data: None,
            error: None,
            metadata: HashMap::new(),
        }
    }
    #[test]
    fn test_api_response_from_result() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);
        assert_eq!(response.request_id, "test_123");
        assert_eq!(response.version, "3.0");
        assert_eq!(response.confidence, 0.98);
        assert!(response.formats.text.is_some());
    }
    #[test]
    fn test_api_response_to_json() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);
        let json = response.to_json().unwrap();
        assert!(json.contains("request_id"));
        assert!(json.contains("test_123"));
        assert!(json.contains("confidence"));
    }
    #[test]
    fn test_api_response_round_trip() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);
        let json = response.to_json().unwrap();
        let parsed = ApiResponse::from_json(&json).unwrap();
        assert_eq!(response.request_id, parsed.request_id);
        assert_eq!(response.confidence, parsed.confidence);
    }
    #[test]
    fn test_error_response() {
        let response = ApiResponse::error(
            "test_456".to_string(),
            "invalid_image",
            "Image format not supported",
        );
        assert_eq!(response.request_id, "test_456");
        assert!(response.error.is_some());
        assert!(response.error_info.is_some());
        let error_info = response.error_info.unwrap();
        assert_eq!(error_info.code, "invalid_image");
    }
    #[test]
    fn test_batch_response() {
        let result1 = create_test_result();
        let result2 = create_test_result();
        let responses = vec![
            ApiResponse::from_ocr_result(result1),
            ApiResponse::from_ocr_result(result2),
        ];
        let batch = BatchApiResponse::new("batch_789".to_string(), responses);
        assert_eq!(batch.batch_id, "batch_789");
        assert_eq!(batch.total, 2);
        assert_eq!(batch.completed, 2);
        assert!(batch.errors.is_none());
    }
    #[test]
    fn test_batch_with_errors() {
        let success = create_test_result();
        let error_response =
            ApiResponse::error("fail_1".to_string(), "timeout", "Processing timeout");
        let responses = vec![ApiResponse::from_ocr_result(success), error_response];
        let batch = BatchApiResponse::new("batch_error".to_string(), responses);
        assert_eq!(batch.total, 2);
        assert_eq!(batch.completed, 1);
        assert!(batch.errors.is_some());
        assert_eq!(batch.errors.unwrap().len(), 1);
    }
    #[test]
    fn test_api_request() {
        let request = ApiRequest {
            src: "https://example.com/image.png".to_string(),
            formats: Some(vec!["text".to_string(), "latex_styled".to_string()]),
            ocr: Some(OcrOptions {
                math_inline_delimiters: Some(vec!["$".to_string(), "$".to_string()]),
                math_display_delimiters: Some(vec!["$$".to_string(), "$$".to_string()]),
                rm_spaces: Some(true),
                rm_fonts: None,
                numbers_default_to_math: Some(false),
            }),
            metadata: HashMap::new(),
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("src"));
        assert!(json.contains("formats"));
    }
}

View File

@@ -0,0 +1,430 @@
//! LaTeX output formatter with styling and package management
use super::LineData;
/// LaTeX document formatter
#[derive(Clone)]
pub struct LaTeXFormatter {
    // \usepackage entries emitted by format_document().
    packages: Vec<String>,
    // \documentclass argument (e.g. "article").
    document_class: String,
    // Extra preamble text inserted verbatim before \begin{document}.
    preamble: String,
    // Use numbered `equation` environments instead of \[ ... \].
    numbered_equations: bool,
    // Optional (open, close) pair wrapped around format() output.
    custom_delimiters: Option<(String, String)>,
}
impl LaTeXFormatter {
/// Default formatter: amsmath + amssymb, `article` class, unnumbered
/// equations, no extra preamble or custom delimiters.
pub fn new() -> Self {
    Self {
        packages: vec!["amsmath".to_string(), "amssymb".to_string()],
        document_class: "article".to_string(),
        preamble: String::new(),
        numbered_equations: false,
        custom_delimiters: None,
    }
}
/// Replace the package list wholesale.
pub fn with_packages(mut self, packages: Vec<String>) -> Self {
    self.packages = packages;
    self
}
/// Add a single package, skipping duplicates.
pub fn add_package(mut self, package: String) -> Self {
    if !self.packages.contains(&package) {
        self.packages.push(package);
    }
    self
}
/// Set the \documentclass argument.
pub fn document_class(mut self, class: String) -> Self {
    self.document_class = class;
    self
}
/// Set verbatim preamble text emitted before \begin{document}.
pub fn preamble(mut self, preamble: String) -> Self {
    self.preamble = preamble;
    self
}
/// Toggle numbered `equation` environments for display math.
pub fn numbered_equations(mut self, numbered: bool) -> Self {
    self.numbered_equations = numbered;
    self
}
/// Wrap `format()` output in a custom (start, end) delimiter pair.
pub fn custom_delimiters(mut self, start: String, end: String) -> Self {
    self.custom_delimiters = Some((start, end));
    self
}
/// Normalize a LaTeX snippet and, when configured, wrap it in the
/// formatter's custom delimiter pair.
pub fn format(&self, latex: &str) -> String {
    let cleaned = self.clean_latex(latex);
    match &self.custom_delimiters {
        Some((open, close)) => format!("{}{}{}", open, cleaned, close),
        None => cleaned,
    }
}
/// Format line data to LaTeX
///
/// Walks the recognized lines and emits LaTeX per `line_type`. Consecutive
/// "align" lines are grouped into one `align*` environment; a "text" line
/// (or the end of input) closes the open environment.
pub fn format_lines(&self, lines: &[LineData]) -> String {
    let mut output = String::new();
    // Tracks whether an align* environment is currently open.
    let mut in_align = false;
    for line in lines {
        match line.line_type.as_str() {
            "text" => {
                if in_align {
                    output.push_str("\\end{align*}\n\n");
                    in_align = false;
                }
                output.push_str(&self.escape_text(&line.text));
                output.push_str("\n\n");
            }
            "math" | "equation" => {
                // NOTE(review): an open align* block is NOT closed here, so a
                // "math" line directly after "align" lines nests a display
                // environment inside align* — confirm whether that is intended.
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                if self.numbered_equations {
                    output.push_str("\\begin{equation}\n");
                    output.push_str(latex.trim());
                    output.push_str("\n\\end{equation}\n\n");
                } else {
                    output.push_str("\\[\n");
                    output.push_str(latex.trim());
                    output.push_str("\n\\]\n\n");
                }
            }
            "inline_math" => {
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                output.push_str(&format!("${}$", latex.trim()));
            }
            "align" => {
                if !in_align {
                    output.push_str("\\begin{align*}\n");
                    in_align = true;
                }
                let latex = line.latex.as_ref().unwrap_or(&line.text);
                output.push_str(latex.trim());
                // Every align row gets a trailing \\, including the last one.
                output.push_str(" \\\\\n");
            }
            "table" => {
                output.push_str(&self.format_table(&line.text));
                output.push_str("\n\n");
            }
            _ => {
                // Unknown line types pass through unescaped.
                output.push_str(&line.text);
                output.push_str("\n");
            }
        }
    }
    // Close a trailing align* environment.
    if in_align {
        output.push_str("\\end{align*}\n");
    }
    output.trim().to_string()
}
/// Wrap `content` in a complete LaTeX document: \documentclass, the
/// configured \usepackage lines, any custom preamble, then the body
/// between \begin{document} and \end{document}.
pub fn format_document(&self, content: &str) -> String {
    let mut doc = format!("\\documentclass{{{}}}\n\n", self.document_class);
    for package in &self.packages {
        doc += &format!("\\usepackage{{{}}}\n", package);
    }
    doc += "\n";
    // Custom preamble (verbatim), if any.
    if !self.preamble.is_empty() {
        doc += &self.preamble;
        doc += "\n\n";
    }
    doc += "\\begin{document}\n\n";
    doc += content;
    doc += "\n\n\\end{document}\n";
    doc
}
/// Clean and normalize LaTeX
///
/// Collapses runs of spaces, normalizes CRLF line endings, and inserts
/// single spaces around a small set of binary operators.
fn clean_latex(&self, latex: &str) -> String {
    let mut cleaned = latex.to_string();
    // Remove excessive whitespace (collapse double spaces until none remain)
    while cleaned.contains("  ") {
        cleaned = cleaned.replace("  ", " ");
    }
    // Normalize line breaks
    cleaned = cleaned.replace("\r\n", "\n");
    // Ensure proper spacing around operators
    // NOTE(review): plain substring replacement — it also spaces a '-' or
    // '+' inside subscripts/arguments (e.g. e^{-x} -> e^{ - x}). Harmless
    // in math mode but worth confirming for edge cases.
    for op in &["=", "+", "-", r"\times", r"\div"] {
        let spaced = format!(" {} ", op);
        cleaned = cleaned.replace(op, &spaced);
    }
    // Remove duplicate spaces again (the insertions above can double up)
    while cleaned.contains("  ") {
        cleaned = cleaned.replace("  ", " ");
    }
    cleaned.trim().to_string()
}
/// Escape characters that are special in LaTeX text mode.
///
/// # Bug fixed
/// The old implementation mapped `\` to `\\` (a *line break* in LaTeX, not
/// a literal backslash) and `^` / `~` to `\^` / `\~` (accent commands that
/// consume the following character). The standard text-mode escapes are
/// `\textbackslash{}`, `\textasciicircum{}` and `\textasciitilde{}`.
/// Characters are mapped in a single pass, so an escape sequence inserted
/// for one character can never be re-escaped by a later rule.
fn escape_text(&self, text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '\\' => escaped.push_str(r"\textbackslash{}"),
            '{' => escaped.push_str(r"\{"),
            '}' => escaped.push_str(r"\}"),
            '$' => escaped.push_str(r"\$"),
            '%' => escaped.push_str(r"\%"),
            '_' => escaped.push_str(r"\_"),
            '&' => escaped.push_str(r"\&"),
            '#' => escaped.push_str(r"\#"),
            '^' => escaped.push_str(r"\textasciicircum{}"),
            '~' => escaped.push_str(r"\textasciitilde{}"),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Format a pipe-delimited table as a LaTeX `tabular` environment with a
/// centered column spec and \hline rules around the header and the body.
///
/// # Bug fixed
/// The column count was previously derived from the first row *without*
/// trimming, while the data cells are trimmed then filtered — so a
/// whitespace-only cell inflated the column specification. Both now use
/// the same trim-then-drop-empty rule.
fn format_table(&self, table: &str) -> String {
    // Split one raw row into trimmed, non-empty cell strings.
    fn cells(row: &str) -> Vec<&str> {
        row.split('|')
            .map(str::trim)
            .filter(|c| !c.is_empty())
            .collect()
    }
    let rows: Vec<&str> = table.lines().collect();
    if rows.is_empty() {
        return String::new();
    }
    // Determine number of columns from the first row, counted with the
    // same rule applied to the data cells below.
    let num_cols = cells(rows[0]).len();
    let col_spec = "c".repeat(num_cols);
    let mut output = format!("\\begin{{tabular}}{{{}}}\n", col_spec);
    output.push_str("\\hline\n");
    for (i, row) in rows.iter().enumerate() {
        output.push_str(&cells(row).join(" & "));
        output.push_str(" \\\\\n");
        if i == 0 {
            // Rule separating the header row from the body.
            output.push_str("\\hline\n");
        }
    }
    output.push_str("\\hline\n");
    output.push_str("\\end{tabular}");
    output
}
/// Convert inline LaTeX to display math
///
/// Honors `numbered_equations`: a numbered `equation` environment versus
/// the unnumbered \[ ... \] form.
pub fn inline_to_display(&self, latex: &str) -> String {
    if self.numbered_equations {
        format!("\\begin{{equation}}\n{}\n\\end{{equation}}", latex.trim())
    } else {
        format!("\\[\n{}\n\\]", latex.trim())
    }
}
/// Add equation label
///
/// Appends `\label{...}` on its own line after the trimmed equation body.
pub fn add_label(&self, latex: &str, label: &str) -> String {
    format!("{}\n\\label{{{}}}", latex.trim(), label)
}
}
/// `Default` is the plain article-style configuration from `new()`.
impl Default for LaTeXFormatter {
    fn default() -> Self {
        Self::new()
    }
}
/// Styled LaTeX formatter with predefined templates
#[allow(dead_code)]
pub struct StyledLaTeXFormatter {
    // Underlying formatter pre-configured for the chosen style.
    base: LaTeXFormatter,
    // Selected template; stored but not read after construction
    // (hence the allow(dead_code) above).
    style: LaTeXStyle,
}
/// Predefined document templates mapping to standard LaTeX classes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaTeXStyle {
    Article,
    Report,
    Book,
    Beamer,
    Minimal,
}
impl StyledLaTeXFormatter {
    /// Build the base formatter for `style`: the matching document class
    /// plus a package set appropriate for that kind of document.
    pub fn new(style: LaTeXStyle) -> Self {
        let base = match style {
            LaTeXStyle::Article => LaTeXFormatter::new()
                .document_class("article".to_string())
                .with_packages(vec![
                    "amsmath".to_string(),
                    "amssymb".to_string(),
                    "graphicx".to_string(),
                    "hyperref".to_string(),
                ]),
            LaTeXStyle::Report => LaTeXFormatter::new()
                .document_class("report".to_string())
                .with_packages(vec![
                    "amsmath".to_string(),
                    "amssymb".to_string(),
                    "graphicx".to_string(),
                    "hyperref".to_string(),
                    "geometry".to_string(),
                ]),
            LaTeXStyle::Book => LaTeXFormatter::new()
                .document_class("book".to_string())
                .with_packages(vec![
                    "amsmath".to_string(),
                    "amssymb".to_string(),
                    "graphicx".to_string(),
                    "hyperref".to_string(),
                    "geometry".to_string(),
                    "fancyhdr".to_string(),
                ]),
            LaTeXStyle::Beamer => LaTeXFormatter::new()
                .document_class("beamer".to_string())
                .with_packages(vec![
                    "amsmath".to_string(),
                    "amssymb".to_string(),
                    "graphicx".to_string(),
                ]),
            LaTeXStyle::Minimal => LaTeXFormatter::new()
                .document_class("article".to_string())
                .with_packages(vec!["amsmath".to_string()]),
        };
        Self { base, style }
    }
    /// Render a full document, optionally with \title/\author metadata;
    /// when either is given, \date{\today} and \maketitle are added too.
    pub fn format_document(
        &self,
        content: &str,
        title: Option<&str>,
        author: Option<&str>,
    ) -> String {
        let mut preamble = String::new();
        if let Some(t) = title {
            preamble.push_str(&format!("\\title{{{}}}\n", t));
        }
        if let Some(a) = author {
            preamble.push_str(&format!("\\author{{{}}}\n", a));
        }
        if title.is_some() || author.is_some() {
            preamble.push_str("\\date{\\today}\n");
        }
        let formatter = self.base.clone().preamble(preamble);
        let mut doc = formatter.format_document(content);
        // Add maketitle after \begin{document} if we have title/author
        if title.is_some() || author.is_some() {
            doc = doc.replace(
                "\\begin{document}\n\n",
                "\\begin{document}\n\n\\maketitle\n\n",
            );
        }
        doc
    }
}
// Unit tests: snippet/document formatting, escaping, styled templates.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;
    #[test]
    fn test_format_simple() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.format("E = mc^2");
        assert!(result.contains("mc^2"));
    }
    #[test]
    fn test_format_document() {
        let formatter = LaTeXFormatter::new();
        let doc = formatter.format_document("E = mc^2");
        assert!(doc.contains(r"\documentclass{article}"));
        assert!(doc.contains(r"\usepackage{amsmath}"));
        assert!(doc.contains(r"\begin{document}"));
        assert!(doc.contains("mc^2"));
        assert!(doc.contains(r"\end{document}"));
    }
    #[test]
    fn test_escape_text() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.escape_text("Price: $100 & 50%");
        assert!(result.contains(r"\$100"));
        assert!(result.contains(r"\&"));
        assert!(result.contains(r"\%"));
    }
    #[test]
    fn test_inline_to_display() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.inline_to_display("x^2 + y^2 = r^2");
        assert!(result.contains(r"\["));
        assert!(result.contains(r"\]"));
    }
    #[test]
    fn test_styled_formatter() {
        let formatter = StyledLaTeXFormatter::new(LaTeXStyle::Article);
        let doc = formatter.format_document("Content", Some("My Title"), Some("Author Name"));
        assert!(doc.contains(r"\title{My Title}"));
        assert!(doc.contains(r"\author{Author Name}"));
        assert!(doc.contains(r"\maketitle"));
    }
    #[test]
    fn test_format_lines() {
        let formatter = LaTeXFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "Introduction".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "equation".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];
        let result = formatter.format_lines(&lines);
        assert!(result.contains("Introduction"));
        assert!(result.contains(r"\[") || result.contains(r"\begin{equation}"));
        assert!(result.contains("mc^2"));
    }
}

View File

@@ -0,0 +1,379 @@
//! Scipix Markdown (MMD) formatter
//!
//! MMD is an enhanced markdown format that supports:
//! - Inline and display math with LaTeX
//! - Tables with alignment
//! - Chemistry notation (SMILES)
//! - Image embedding
//! - Structured documents
use super::{LineData, MathDelimiters};
/// Scipix Markdown formatter
///
/// Converts structured OCR line data into MMD text; configured through the
/// builder-style methods on its impl.
pub struct MmdFormatter {
    // Delimiters used when wrapping inline/display math.
    delimiters: MathDelimiters,
    // NOTE(review): stored by the builders but not read by any method in
    // this file yet — confirm intended use before relying on it.
    include_metadata: bool,
    // NOTE(review): stored but not read by any method in this file yet.
    preserve_structure: bool,
}
impl MmdFormatter {
    /// Create a formatter with default delimiters, no metadata emission,
    /// and structure preservation enabled.
    pub fn new() -> Self {
        Self {
            delimiters: MathDelimiters::default(),
            include_metadata: false,
            preserve_structure: true,
        }
    }

    /// Create a formatter using custom math delimiters.
    pub fn with_delimiters(delimiters: MathDelimiters) -> Self {
        Self {
            delimiters,
            include_metadata: false,
            preserve_structure: true,
        }
    }

    /// Builder: toggle metadata emission.
    /// NOTE(review): the flag is stored but not yet read by any method here.
    pub fn include_metadata(mut self, include: bool) -> Self {
        self.include_metadata = include;
        self
    }

    /// Builder: toggle structure preservation.
    /// NOTE(review): the flag is stored but not yet read by any method here.
    pub fn preserve_structure(mut self, preserve: bool) -> Self {
        self.preserve_structure = preserve;
        self
    }

    /// Format line data to MMD.
    ///
    /// Dispatches on `line_type`; unknown types fall back to plain text.
    /// Table and list runs are closed with a blank line when a plain text
    /// line follows them.
    pub fn format(&self, lines: &[LineData]) -> String {
        let mut output = String::new();
        let mut in_table = false;
        let mut in_list = false;
        for line in lines {
            match line.line_type.as_str() {
                "text" => {
                    // Close an open table context with a blank line.
                    if in_table {
                        output.push('\n');
                        in_table = false;
                    }
                    // NOTE(review): only '-', '*' and the digit '1' count as
                    // list markers — confirm whether other ordered-list
                    // digits should be recognized too.
                    if in_list && !line.text.trim_start().starts_with(&['-', '*', '1']) {
                        output.push('\n');
                        in_list = false;
                    }
                    output.push_str(&line.text);
                    output.push('\n');
                }
                "math" | "equation" => {
                    // Prefer the dedicated LaTeX field, fall back to raw text.
                    let latex = line.latex.as_ref().unwrap_or(&line.text);
                    output.push_str(&self.format_math(latex, true)); // display mode
                    output.push_str("\n\n");
                }
                "inline_math" => {
                    let latex = line.latex.as_ref().unwrap_or(&line.text);
                    output.push_str(&self.format_math(latex, false)); // inline mode
                }
                "table_row" => {
                    if !in_table {
                        in_table = true;
                    }
                    output.push_str(&self.format_table_row(&line.text));
                    output.push('\n');
                }
                "list_item" => {
                    if !in_list {
                        in_list = true;
                    }
                    output.push_str(&line.text);
                    output.push('\n');
                }
                "heading" => {
                    // Line data carries no heading depth; always emit H1.
                    output.push_str(&format!("# {}\n\n", line.text));
                }
                "image" => {
                    output.push_str(&self.format_image(&line.text));
                    output.push_str("\n\n");
                }
                "chemistry" => {
                    let smiles = line.text.trim();
                    output.push_str(&format!("```smiles\n{}\n```\n\n", smiles));
                }
                _ => {
                    // Unknown type: pass the text through unchanged.
                    output.push_str(&line.text);
                    output.push('\n');
                }
            }
        }
        output.trim().to_string()
    }

    /// Wrap a LaTeX expression in the configured delimiters.
    /// Display math gets its own lines; inline math is wrapped in place.
    pub fn format_math(&self, latex: &str, display: bool) -> String {
        if display {
            format!(
                "{}\n{}\n{}",
                self.delimiters.display_start,
                latex.trim(),
                self.delimiters.display_end
            )
        } else {
            format!(
                "{}{}{}",
                self.delimiters.inline_start,
                latex.trim(),
                self.delimiters.inline_end
            )
        }
    }

    /// Format a table row as a Markdown pipe row.
    ///
    /// Fix: leading/trailing pipes already present in the input are trimmed
    /// before splitting. The previous version split "|a|b|" into
    /// ["", "a", "b", ""] and emitted a row with spurious empty edge cells.
    fn format_table_row(&self, row: &str) -> String {
        let cells: Vec<&str> = row
            .trim()
            .trim_matches('|')
            .split('|')
            .map(|s| s.trim())
            .collect();
        format!("| {} |", cells.join(" | "))
    }

    /// Format an image reference as Markdown.
    fn format_image(&self, path: &str) -> String {
        // If it already looks like Markdown image syntax, pass it through.
        if path.contains('[') && path.contains(']') {
            path.to_string()
        } else {
            format!("![Image]({})", path)
        }
    }

    /// Convert plain text with embedded `$...$` / `$$...$$` LaTeX spans to
    /// MMD, re-wrapping each span with the configured delimiters.
    /// An unterminated math span is emitted raw at the end.
    pub fn from_mixed_text(&self, text: &str) -> String {
        let mut output = String::new();
        let mut current = String::new();
        let mut in_math = false;
        let mut display_math = false;
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;
        while i < chars.len() {
            // Display math toggles on "$$".
            if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '$' {
                if in_math && display_math {
                    // End display math.
                    output.push_str(&self.format_math(&current, true));
                    current.clear();
                    in_math = false;
                    display_math = false;
                } else if !in_math {
                    // Start display math, flushing accumulated text first.
                    if !current.is_empty() {
                        output.push_str(&current);
                        current.clear();
                    }
                    in_math = true;
                    display_math = true;
                }
                i += 2;
                continue;
            }
            // Inline math toggles on "$" (ignored while inside display math).
            if chars[i] == '$' && !display_math {
                if in_math {
                    // End inline math.
                    output.push_str(&self.format_math(&current, false));
                    current.clear();
                    in_math = false;
                } else {
                    // Start inline math, flushing accumulated text first.
                    if !current.is_empty() {
                        output.push_str(&current);
                        current.clear();
                    }
                    in_math = true;
                }
                i += 1;
                continue;
            }
            current.push(chars[i]);
            i += 1;
        }
        if !current.is_empty() {
            output.push_str(&current);
        }
        output
    }

    /// Format a complete document: optional YAML frontmatter, an H1 title,
    /// then the content.
    pub fn format_document(&self, title: &str, content: &str, metadata: Option<&str>) -> String {
        let mut doc = String::new();
        if let Some(meta) = metadata {
            doc.push_str("---\n");
            doc.push_str(meta);
            doc.push_str("\n---\n\n");
        }
        doc.push_str(&format!("# {}\n\n", title));
        doc.push_str(content);
        doc
    }
}
impl Default for MmdFormatter {
    // Equivalent to `MmdFormatter::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Parse MMD back to structured data
pub struct MmdParser;

impl MmdParser {
    /// Create a new parser; the type is stateless.
    pub fn new() -> Self {
        MmdParser
    }

    /// Scan `content` for `$...$` and `$$...$$` spans and return each
    /// trimmed expression paired with a flag that is `true` for display
    /// math. Unterminated spans are dropped; a lone `$` inside a `$$` block
    /// is kept as part of the expression.
    pub fn extract_latex(&self, content: &str) -> Vec<(String, bool)> {
        let symbols: Vec<char> = content.chars().collect();
        let mut found = Vec::new();
        let mut buffer = String::new();
        let mut inside = false;
        let mut is_display = false;
        let mut pos = 0;

        while pos < symbols.len() {
            let double_dollar =
                symbols[pos] == '$' && pos + 1 < symbols.len() && symbols[pos + 1] == '$';

            if double_dollar {
                if inside && is_display {
                    // Closing "$$": emit the display expression.
                    found.push((buffer.trim().to_string(), true));
                    buffer.clear();
                    inside = false;
                    is_display = false;
                } else if !inside {
                    // Opening "$$".
                    inside = true;
                    is_display = true;
                }
                pos += 2;
            } else if symbols[pos] == '$' && !is_display {
                if inside {
                    // Closing "$": emit the inline expression.
                    found.push((buffer.trim().to_string(), false));
                    buffer.clear();
                    inside = false;
                } else {
                    // Opening "$".
                    inside = true;
                }
                pos += 1;
            } else {
                if inside {
                    buffer.push(symbols[pos]);
                }
                pos += 1;
            }
        }
        found
    }
}
impl Default for MmdParser {
    // Equivalent to `MmdParser::new()`.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;

    // Inline math uses single-dollar delimiters from the defaults.
    #[test]
    fn test_format_inline_math() {
        let formatter = MmdFormatter::new();
        let result = formatter.format_math("E = mc^2", false);
        assert_eq!(result, "$E = mc^2$");
    }

    // Display math uses double-dollar delimiters.
    #[test]
    fn test_format_display_math() {
        let formatter = MmdFormatter::new();
        let result = formatter.format_math(r"\int_0^1 x^2 dx", true);
        assert!(result.contains("$$"));
        assert!(result.contains(r"\int_0^1 x^2 dx"));
    }

    // A text line followed by a math line renders text + display math.
    #[test]
    fn test_format_lines() {
        let formatter = MmdFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "The equation".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "math".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];
        let result = formatter.format(&lines);
        assert!(result.contains("The equation"));
        assert!(result.contains("$$"));
        assert!(result.contains("mc^2"));
    }

    // Embedded $...$ spans are re-wrapped; surrounding prose is preserved.
    #[test]
    fn test_from_mixed_text() {
        let formatter = MmdFormatter::new();
        let text = "The formula $E = mc^2$ is famous.";
        let result = formatter.from_mixed_text(text);
        assert!(result.contains("$E = mc^2$"));
        assert!(result.contains("famous"));
    }

    // The parser reports inline (false) and display (true) spans in order.
    #[test]
    fn test_extract_latex() {
        let parser = MmdParser::new();
        let content = "Text with $inline$ and $$display$$ math.";
        let expressions = parser.extract_latex(content);
        assert_eq!(expressions.len(), 2);
        assert_eq!(expressions[0].0, "inline");
        assert!(!expressions[0].1); // inline
        assert_eq!(expressions[1].0, "display");
        assert!(expressions[1].1); // display
    }

    // Frontmatter, H1 title, and content are all present in a document.
    #[test]
    fn test_format_document() {
        let formatter = MmdFormatter::new();
        let doc = formatter.format_document(
            "My Document",
            "Content here",
            Some("author: Test\ndate: 2025-01-01"),
        );
        assert!(doc.contains("---"));
        assert!(doc.contains("author: Test"));
        assert!(doc.contains("# My Document"));
        assert!(doc.contains("Content here"));
    }
}

View File

@@ -0,0 +1,359 @@
//! Output formatting module for Scipix OCR results
//!
//! Supports multiple output formats:
//! - Text: Plain text extraction
//! - LaTeX: Mathematical notation
//! - Scipix Markdown (mmd): Enhanced markdown with math
//! - MathML: XML-based mathematical markup
//! - HTML: Web-ready output with math rendering
//! - SMILES: Chemical structure notation
//! - DOCX: Microsoft Word format (Office Math ML)
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
// Per-format submodules.
pub mod docx;
pub mod formatter;
pub mod html;
pub mod json;
pub mod latex;
pub mod mmd;
pub mod smiles;
// Re-export the most commonly used items at the module root.
pub use formatter::{HtmlEngine, MathDelimiters, OutputFormatter};
pub use json::ApiResponse;
/// Output format types supported by Scipix OCR
///
/// Serialized in snake_case; several variants carry explicit serde renames,
/// so the wire names below are the API contract.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OutputFormat {
    /// Plain text output (wire name: "text")
    Text,
    /// LaTeX mathematical notation (wire name: "latex_normal")
    #[serde(rename = "latex_normal")]
    LaTeX,
    /// Styled LaTeX with custom packages (wire name: "latex_styled")
    #[serde(rename = "latex_styled")]
    LaTeXStyled,
    /// Mathematical Markup Language (wire name: "mathml")
    #[serde(rename = "mathml")]
    MathML,
    /// Scipix Markdown, enhanced markdown (wire name: "mmd")
    #[serde(rename = "mmd")]
    Mmd,
    /// ASCII Math notation (wire name: "asciimath")
    #[serde(rename = "asciimath")]
    AsciiMath,
    /// HTML with embedded math (wire name: "html")
    Html,
    /// Chemical structure notation (wire name: "smiles")
    #[serde(rename = "smiles")]
    Smiles,
    /// Microsoft Word format (wire name: "docx")
    Docx,
}
impl OutputFormat {
/// Get the file extension for this format
pub fn extension(&self) -> &'static str {
match self {
OutputFormat::Text => "txt",
OutputFormat::LaTeX | OutputFormat::LaTeXStyled => "tex",
OutputFormat::MathML => "xml",
OutputFormat::Mmd => "mmd",
OutputFormat::AsciiMath => "txt",
OutputFormat::Html => "html",
OutputFormat::Smiles => "smi",
OutputFormat::Docx => "docx",
}
}
/// Get the MIME type for this format
pub fn mime_type(&self) -> &'static str {
match self {
OutputFormat::Text | OutputFormat::AsciiMath => "text/plain",
OutputFormat::LaTeX | OutputFormat::LaTeXStyled => "application/x-latex",
OutputFormat::MathML => "application/mathml+xml",
OutputFormat::Mmd => "text/markdown",
OutputFormat::Html => "text/html",
OutputFormat::Smiles => "chemical/x-daylight-smiles",
OutputFormat::Docx => {
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
}
}
}
}
/// Complete OCR result with all possible output formats
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Request identifier
    pub request_id: String,
    /// Version of the OCR engine
    pub version: String,
    /// Image width in pixels
    pub image_width: u32,
    /// Image height in pixels
    pub image_height: u32,
    /// Printed-text classification flag
    pub is_printed: bool,
    /// Handwriting classification flag
    pub is_handwritten: bool,
    /// Confidence of the auto-rotation decision
    pub auto_rotate_confidence: f32,
    /// Rotation applied by auto-rotation, in degrees
    pub auto_rotate_degrees: i32,
    /// Overall confidence score
    pub confidence: f32,
    /// Confidence rate
    pub confidence_rate: f32,
    /// Available output formats
    pub formats: FormatsData,
    /// Detailed line and word data (omitted from JSON when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_data: Option<Vec<LineData>>,
    /// Error information if processing failed (omitted when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Processing metadata, flattened into the top-level JSON object
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}
/// All available output format data
///
/// Every field is optional; only the formats actually produced are
/// serialized (each carries `skip_serializing_if`).
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FormatsData {
    /// Plain text rendering
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
    /// Standard LaTeX
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_normal: Option<String>,
    /// Styled LaTeX
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_styled: Option<String>,
    /// Simplified LaTeX
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_simplified: Option<String>,
    /// MathML markup
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mathml: Option<String>,
    /// ASCII math notation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub asciimath: Option<String>,
    /// Scipix Markdown
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mmd: Option<String>,
    /// HTML rendering
    #[serde(skip_serializing_if = "Option::is_none")]
    pub html: Option<String>,
    /// Chemical structure notation
    #[serde(skip_serializing_if = "Option::is_none")]
    pub smiles: Option<String>,
}
/// Line-level OCR data with positioning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineData {
    /// Line type: text, math, table, image, etc. (serialized as "type")
    #[serde(rename = "type")]
    pub line_type: String,
    /// Plain-text content of the line
    pub text: String,
    /// LaTeX content, when the line is mathematical (omitted when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,
    /// Bounding box coordinates
    pub bbox: BoundingBox,
    /// Confidence score
    pub confidence: f32,
    /// Word-level data (omitted when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<WordData>>,
}
/// Word-level OCR data
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordData {
    /// Recognized word text
    pub text: String,
    /// Bounding box of the word
    pub bbox: BoundingBox,
    /// Confidence score
    pub confidence: f32,
    /// LaTeX form of the word, when mathematical (omitted when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,
}
/// Bounding box coordinates (x, y, width, height)
///
/// `x`/`y` anchor the box and `width`/`height` give its extent; `center()`
/// computes the midpoint from them.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BoundingBox {
    pub x: f32,
    pub y: f32,
    pub width: f32,
    pub height: f32,
}
impl BoundingBox {
pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
Self {
x,
y,
width,
height,
}
}
pub fn area(&self) -> f32 {
self.width * self.height
}
pub fn center(&self) -> (f32, f32) {
(self.x + self.width / 2.0, self.y + self.height / 2.0)
}
}
/// Convert between output formats
///
/// Same-format conversions are pass-through. Supported cross-conversions:
/// LaTeX → Text (markup stripped), Mmd → LaTeX (math extracted), and
/// LaTeX → HTML (wrapped in a MathJax page).
///
/// # Errors
/// Returns an error string for any unsupported format pair.
pub fn convert_format(
    content: &str,
    from: OutputFormat,
    to: OutputFormat,
) -> Result<String, String> {
    // Simple pass-through for same format
    if from == to {
        return Ok(content.to_string());
    }
    // Format-specific conversions
    match (from, to) {
        (OutputFormat::LaTeX, OutputFormat::Text) => {
            // Strip LaTeX commands for plain text
            Ok(strip_latex(content))
        }
        (OutputFormat::Mmd, OutputFormat::LaTeX) => {
            // Extract LaTeX from markdown
            Ok(extract_latex_from_mmd(content))
        }
        (OutputFormat::LaTeX, OutputFormat::Html) => {
            // Wrap LaTeX in HTML with MathJax.
            // SECURITY: the polyfill.io <script> tag was removed — that
            // domain was taken over in 2024 and began serving malicious
            // code to visitors. MathJax 3 runs on modern browsers without
            // the polyfill.
            Ok(format!(
                r#"<!DOCTYPE html>
<html>
<head>
<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</head>
<body>
<p>\({}\)</p>
</body>
</html>"#,
                content
            ))
        }
        _ => Err(format!(
            "Conversion from {:?} to {:?} not supported",
            from, to
        )),
    }
}
/// Best-effort removal of LaTeX markup, leaving readable plain text.
///
/// Handles math delimiters, a few text-style wrapper commands, and common
/// spacing commands. Lossy by design: both brace characters are removed
/// wholesale, so unknown commands keep their name but lose their braces.
fn strip_latex(content: &str) -> String {
    let mut result = content.to_string();
    // Remove math delimiters
    result = result.replace("\\(", "").replace("\\)", "");
    result = result.replace("\\[", "").replace("\\]", "");
    result = result.replace("$$", "");
    // Remove common commands but keep their content
    for cmd in &["\\text", "\\mathrm", "\\mathbf", "\\mathit"] {
        result = result.replace(&format!("{}{}", cmd, "{"), "");
    }
    // Fix: drop BOTH brace kinds. The original stripped only "}", which
    // left stray "{" behind for any command not in the list above
    // (e.g. \frac{a}{b} became "\frac{a{b").
    result = result.replace("{", "").replace("}", "");
    // Replace standalone spacing/line-break commands with spaces
    for cmd in &["\\\\", "\\,", "\\;", "\\:", "\\!", "\\quad", "\\qquad"] {
        result = result.replace(cmd, " ");
    }
    result.trim().to_string()
}
/// Extract every `$...$` and `$$...$$` LaTeX expression from MMD `content`,
/// joined by blank lines.
///
/// Fix: inline spans are now captured, consistent with
/// `MmdParser::extract_latex`. The previous version toggled state on a
/// single `$` without flushing or clearing the buffer, so inline math was
/// dropped AND its text leaked into the next display-math expression.
fn extract_latex_from_mmd(content: &str) -> String {
    let mut latex_parts = Vec::new();
    let mut in_math = false;
    let mut display = false;
    let mut current = String::new();
    let chars: Vec<char> = content.chars().collect();
    let mut i = 0;
    while i < chars.len() {
        if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '$' {
            if in_math && display {
                // Closing "$$": flush the display expression.
                latex_parts.push(current.trim().to_string());
                current.clear();
                in_math = false;
                display = false;
            } else if !in_math {
                in_math = true;
                display = true;
            }
            i += 2;
        } else if chars[i] == '$' && !display {
            if in_math {
                // Closing "$": flush the inline expression.
                latex_parts.push(current.trim().to_string());
                current.clear();
                in_math = false;
            } else {
                in_math = true;
            }
            i += 1;
        } else {
            if in_math {
                current.push(chars[i]);
            }
            i += 1;
        }
    }
    latex_parts.join("\n\n")
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each format maps to its conventional file extension.
    #[test]
    fn test_output_format_extension() {
        assert_eq!(OutputFormat::Text.extension(), "txt");
        assert_eq!(OutputFormat::LaTeX.extension(), "tex");
        assert_eq!(OutputFormat::Html.extension(), "html");
        assert_eq!(OutputFormat::Mmd.extension(), "mmd");
    }

    // Each format maps to its MIME type for HTTP responses.
    #[test]
    fn test_output_format_mime_type() {
        assert_eq!(OutputFormat::Text.mime_type(), "text/plain");
        assert_eq!(OutputFormat::LaTeX.mime_type(), "application/x-latex");
        assert_eq!(OutputFormat::Html.mime_type(), "text/html");
    }

    // area = w*h; center = anchor + half-extent.
    #[test]
    fn test_bounding_box() {
        let bbox = BoundingBox::new(10.0, 20.0, 100.0, 50.0);
        assert_eq!(bbox.area(), 5000.0);
        assert_eq!(bbox.center(), (60.0, 45.0));
    }

    // Wrapper commands are removed while their content survives.
    #[test]
    fn test_strip_latex() {
        let input = r"\text{Hello } \mathbf{World}";
        let output = strip_latex(input);
        assert!(output.contains("Hello"));
        assert!(output.contains("World"));
    }

    // Same-format conversion is an identity pass-through.
    #[test]
    fn test_convert_same_format() {
        let content = "test content";
        let result = convert_format(content, OutputFormat::Text, OutputFormat::Text).unwrap();
        assert_eq!(result, content);
    }
}

View File

@@ -0,0 +1,347 @@
//! SMILES (Simplified Molecular Input Line Entry System) generator
//!
//! Converts chemical structure representations to SMILES notation.
//! This is a simplified implementation - full chemistry support requires
//! dedicated chemistry libraries like RDKit or OpenBabel.
use super::OcrResult;
/// SMILES notation generator for chemical structures
pub struct SmilesGenerator {
    // NOTE(review): both flags are set by the builder methods but never
    // read by any method in this file — presumably reserved for a fuller
    // chemistry implementation. Confirm before relying on them.
    canonical: bool,
    include_stereochemistry: bool,
}
impl SmilesGenerator {
    /// Create a generator with canonical output and stereochemistry enabled.
    pub fn new() -> Self {
        Self {
            canonical: true,
            include_stereochemistry: true,
        }
    }

    /// Builder: request canonical SMILES output.
    pub fn canonical(mut self, canonical: bool) -> Self {
        self.canonical = canonical;
        self
    }

    /// Builder: request stereochemistry annotations.
    pub fn stereochemistry(mut self, include: bool) -> Self {
        self.include_stereochemistry = include;
        self
    }

    /// Generate SMILES from an OCR result.
    ///
    /// Prefers a SMILES string already present in `result.formats`;
    /// otherwise scans line data for "chemistry"/"molecule" lines.
    ///
    /// # Errors
    /// Returns an error when no chemical structure data is found or the
    /// notation cannot be converted.
    pub fn generate_from_result(&self, result: &OcrResult) -> Result<String, String> {
        // Check if SMILES already available
        if let Some(smiles) = &result.formats.smiles {
            return Ok(smiles.clone());
        }
        // Check for chemistry-related content in line data
        if let Some(line_data) = &result.line_data {
            for line in line_data {
                if line.line_type == "chemistry" || line.line_type == "molecule" {
                    return self.parse_chemical_notation(&line.text);
                }
            }
        }
        Err("No chemical structure data found".to_string())
    }

    /// Parse chemical notation to SMILES.
    /// Placeholder — real implementation needs chemistry parsing.
    fn parse_chemical_notation(&self, notation: &str) -> Result<String, String> {
        // Pass through if already SMILES format
        if self.is_smiles(notation) {
            return Ok(notation.to_string());
        }
        // Try to parse common chemical formulas
        if let Some(smiles) = self.simple_formula_to_smiles(notation) {
            return Ok(smiles);
        }
        Err(format!("Cannot convert '{}' to SMILES", notation))
    }

    /// Heuristic check whether `s` already looks like SMILES.
    ///
    /// Fix: the empty string no longer passes. `all()` on an empty iterator
    /// is vacuously true, so "" was previously accepted as valid SMILES and
    /// returned unchanged by `parse_chemical_notation`.
    fn is_smiles(&self, s: &str) -> bool {
        // Basic SMILES characters
        let smiles_chars = "CNOPSFClBrI[]()=#@+-0123456789cnops";
        !s.is_empty() && s.chars().all(|c| smiles_chars.contains(c))
    }

    /// Convert well-known simple formulas / trivial names to SMILES.
    fn simple_formula_to_smiles(&self, formula: &str) -> Option<String> {
        match formula.trim() {
            "H2O" | "water" => Some("O".to_string()),
            "CO2" | "carbon dioxide" => Some("O=C=O".to_string()),
            "CH4" | "methane" => Some("C".to_string()),
            "C2H6" | "ethane" => Some("CC".to_string()),
            "C2H5OH" | "ethanol" => Some("CCO".to_string()),
            "CH3COOH" | "acetic acid" => Some("CC(=O)O".to_string()),
            "C6H6" | "benzene" => Some("c1ccccc1".to_string()),
            "C6H12O6" | "glucose" => Some("OC[C@H]1OC(O)[C@H](O)[C@@H](O)[C@@H]1O".to_string()),
            "NH3" | "ammonia" => Some("N".to_string()),
            "H2SO4" | "sulfuric acid" => Some("OS(=O)(=O)O".to_string()),
            "NaCl" | "sodium chloride" => Some("[Na+].[Cl-]".to_string()),
            _ => None,
        }
    }

    /// Validate SMILES notation (parenthesis and bracket balance only).
    ///
    /// # Errors
    /// Returns an error describing the first unbalanced delimiter kind
    /// found; parentheses are checked before brackets.
    pub fn validate(&self, smiles: &str) -> Result<(), String> {
        // Check parentheses balance
        let mut depth = 0;
        for c in smiles.chars() {
            match c {
                '(' => depth += 1,
                ')' => {
                    depth -= 1;
                    if depth < 0 {
                        return Err("Unbalanced parentheses".to_string());
                    }
                }
                _ => {}
            }
        }
        if depth != 0 {
            return Err("Unbalanced parentheses".to_string());
        }
        // Check brackets balance
        let mut depth = 0;
        for c in smiles.chars() {
            match c {
                '[' => depth += 1,
                ']' => {
                    depth -= 1;
                    if depth < 0 {
                        return Err("Unbalanced brackets".to_string());
                    }
                }
                _ => {}
            }
        }
        if depth != 0 {
            return Err("Unbalanced brackets".to_string());
        }
        Ok(())
    }

    /// Convert SMILES to an approximate molecular formula.
    ///
    /// Simplified: counts only single uppercase letters, so two-letter
    /// elements are miscounted (e.g. the C of "Cl" is tallied as carbon)
    /// and hydrogens are never inferred.
    ///
    /// # Errors
    /// Fails on invalid SMILES or when no known element is found.
    pub fn to_molecular_formula(&self, smiles: &str) -> Result<String, String> {
        self.validate(smiles)?;
        let mut counts: std::collections::HashMap<char, usize> = std::collections::HashMap::new();
        for c in smiles.chars() {
            if c.is_alphabetic() && c.is_uppercase() {
                *counts.entry(c).or_insert(0) += 1;
            }
        }
        let mut formula = String::new();
        // Only single-character elements, in fixed Hill-ish order.
        for element in &['C', 'H', 'N', 'O', 'S', 'P', 'F'] {
            if let Some(&count) = counts.get(element) {
                formula.push(*element);
                if count > 1 {
                    formula.push_str(&count.to_string());
                }
            }
        }
        if formula.is_empty() {
            Err("Could not determine molecular formula".to_string())
        } else {
            Ok(formula)
        }
    }

    /// Calculate an approximate molecular weight.
    ///
    /// Simplified: sums weights of single-letter element characters only;
    /// implicit hydrogens and two-letter elements are not accounted for.
    ///
    /// # Errors
    /// Fails on invalid SMILES.
    pub fn molecular_weight(&self, smiles: &str) -> Result<f32, String> {
        self.validate(smiles)?;
        // Simplified atomic weights
        let weights: std::collections::HashMap<char, f32> = [
            ('C', 12.01),
            ('H', 1.008),
            ('N', 14.01),
            ('O', 16.00),
            ('S', 32.07),
            ('P', 30.97),
            ('F', 19.00),
        ]
        .iter()
        .cloned()
        .collect();
        let mut total_weight = 0.0;
        for c in smiles.chars() {
            if let Some(&weight) = weights.get(&c) {
                total_weight += weight;
            }
        }
        Ok(total_weight)
    }
}
impl Default for SmilesGenerator {
    // Equivalent to `SmilesGenerator::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// SMILES parser for extracting structure information
pub struct SmilesParser;

impl SmilesParser {
    /// Common two-letter element symbols. A lowercase letter following an
    /// uppercase one is merged only when the pair appears here, so aromatic
    /// atoms (e.g. the `n` in `Cn`) are not swallowed into a bogus symbol.
    const TWO_LETTER_ELEMENTS: [&'static str; 12] = [
        "Al", "Br", "Ca", "Cl", "Cu", "Fe", "Li", "Mg", "Na", "Se", "Si", "Zn",
    ];

    pub fn new() -> Self {
        Self
    }

    /// Count atoms in SMILES notation.
    ///
    /// Fixes over the original:
    /// * Aromatic organic-subset atoms (`b c n o p s`) are counted under
    ///   their uppercase element symbol; previously `c1ccccc1` reported no
    ///   carbon at all.
    /// * Two-letter merging is restricted to known element symbols;
    ///   previously ANY uppercase+lowercase pair was merged (`Cn` → "Cn").
    pub fn count_atoms(&self, smiles: &str) -> std::collections::HashMap<String, usize> {
        let mut counts = std::collections::HashMap::new();
        let chars: Vec<char> = smiles.chars().collect();
        let mut i = 0;
        while i < chars.len() {
            let c = chars[i];
            if c.is_ascii_uppercase() {
                let mut atom = c.to_string();
                // Merge a following lowercase letter only for known
                // two-letter elements (Cl, Br, Na, ...).
                if i + 1 < chars.len() && chars[i + 1].is_ascii_lowercase() {
                    let pair: String = [c, chars[i + 1]].iter().collect();
                    if Self::TWO_LETTER_ELEMENTS.contains(&pair.as_str()) {
                        atom = pair;
                        i += 1;
                    }
                }
                *counts.entry(atom).or_insert(0) += 1;
            } else if matches!(c, 'b' | 'c' | 'n' | 'o' | 'p' | 's') {
                // Aromatic atom written in lowercase.
                *counts
                    .entry(c.to_ascii_uppercase().to_string())
                    .or_insert(0) += 1;
            }
            i += 1;
        }
        counts
    }

    /// Extract ring-closure digits.
    ///
    /// Simplified: `%nn` closures and digits inside bracket atoms (e.g.
    /// isotopes like `[13C]`) are not distinguished from ring labels.
    pub fn find_rings(&self, smiles: &str) -> Vec<usize> {
        let mut rings = Vec::new();
        for c in smiles.chars() {
            if let Some(digit) = c.to_digit(10) {
                rings.push(digit as usize);
            }
        }
        rings
    }
}
impl Default for SmilesParser {
    // Equivalent to `SmilesParser::new()`.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // SMILES heuristic accepts plain and aromatic notation, rejects prose.
    #[test]
    fn test_is_smiles() {
        let gen = SmilesGenerator::new();
        assert!(gen.is_smiles("CCO"));
        assert!(gen.is_smiles("c1ccccc1"));
        assert!(gen.is_smiles("CC(=O)O"));
        assert!(!gen.is_smiles("not smiles!"));
    }

    // Formula/name lookup table round-trips a few well-known molecules.
    #[test]
    fn test_simple_formula_conversion() {
        let gen = SmilesGenerator::new();
        assert_eq!(gen.simple_formula_to_smiles("H2O"), Some("O".to_string()));
        assert_eq!(
            gen.simple_formula_to_smiles("CO2"),
            Some("O=C=O".to_string())
        );
        assert_eq!(gen.simple_formula_to_smiles("CH4"), Some("C".to_string()));
        assert_eq!(
            gen.simple_formula_to_smiles("benzene"),
            Some("c1ccccc1".to_string())
        );
    }

    // Validation checks parenthesis balance.
    #[test]
    fn test_validate_smiles() {
        let gen = SmilesGenerator::new();
        assert!(gen.validate("CCO").is_ok());
        assert!(gen.validate("CC(O)C").is_ok());
        assert!(gen.validate("c1ccccc1").is_ok());
        assert!(gen.validate("CC(O").is_err()); // Unbalanced
        assert!(gen.validate("CC)O").is_err()); // Unbalanced
    }

    // Ethanol contains both carbon and oxygen in the derived formula.
    #[test]
    fn test_molecular_formula() {
        let gen = SmilesGenerator::new();
        let formula = gen.to_molecular_formula("CCO").unwrap();
        assert!(formula.contains('C'));
        assert!(formula.contains('O'));
    }

    // Weights are positive and roughly ordered by molecule size.
    #[test]
    fn test_molecular_weight() {
        let gen = SmilesGenerator::new();
        // Water: H2O (but SMILES is just "O", representing OH2)
        let weight = gen.molecular_weight("O").unwrap();
        assert!(weight > 0.0);
        // Ethanol: C2H6O
        let weight = gen.molecular_weight("CCO").unwrap();
        assert!(weight > 30.0); // Should be around 46
    }

    // Atom counting across plain and branched SMILES.
    #[test]
    fn test_count_atoms() {
        let parser = SmilesParser::new();
        let counts = parser.count_atoms("CCO");
        assert_eq!(counts.get("C"), Some(&2));
        assert_eq!(counts.get("O"), Some(&1));
        let counts = parser.count_atoms("CC(=O)O");
        assert_eq!(counts.get("C"), Some(&2));
        assert_eq!(counts.get("O"), Some(&2));
    }

    // Ring-closure digits appear twice (open + close).
    #[test]
    fn test_find_rings() {
        let parser = SmilesParser::new();
        let rings = parser.find_rings("c1ccccc1");
        assert_eq!(rings, vec![1, 1]);
        let rings = parser.find_rings("C1CC1");
        assert_eq!(rings, vec![1, 1]);
    }
}

View File

@@ -0,0 +1,353 @@
//! Skew detection and correction using Hough transform
use super::{PreprocessError, Result};
use image::{GrayImage, Luma};
use imageproc::edges::canny;
use imageproc::geometric_transformations::{rotate_about_center, Interpolation};
use std::collections::BTreeMap;
use std::f32;
/// Detect skew angle using Hough transform
///
/// Runs Canny edge detection, accumulates edge pixels in Hough space via
/// `detect_lines_hough`, then returns the vote-weighted mean of the
/// detected near-horizontal line angles.
///
/// # Arguments
/// * `image` - Input grayscale image
///
/// # Returns
/// Skew angle in degrees (positive = clockwise); `0.0` when no lines are
/// detected.
///
/// # Errors
/// Fails when either image dimension is below 20 pixels.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::deskew::detect_skew_angle;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let angle = detect_skew_angle(&image).unwrap();
/// println!("Detected skew: {:.2}°", angle);
/// ```
pub fn detect_skew_angle(image: &GrayImage) -> Result<f32> {
    let (width, height) = image.dimensions();
    if width < 20 || height < 20 {
        return Err(PreprocessError::InvalidParameters(
            "Image too small for skew detection".to_string(),
        ));
    }

    // Edge map feeds the Hough accumulator.
    let edge_map = canny(image, 50.0, 100.0);
    let votes = detect_lines_hough(&edge_map, width, height)?;
    if votes.is_empty() {
        return Ok(0.0);
    }

    // Vote-weighted mean; map keys store angle * 10 as integers.
    let (weighted_sum, total_weight) = votes
        .iter()
        .fold((0.0_f32, 0.0_f32), |(acc, total), (key, weight)| {
            (acc + (*key as f32 / 10.0) * weight, total + weight)
        });

    if total_weight > 0.0 {
        Ok(weighted_sum / total_weight)
    } else {
        Ok(0.0)
    }
}
/// Detect lines using Hough transform
///
/// Returns a map from angle keys (`angle_degrees * 10`, truncated to an
/// integer) to accumulated vote weights for near-horizontal lines.
///
/// Fixes over the original version:
/// * The image diagonal is computed in `f64`, so `width * width` can no
///   longer overflow `u32` for very large images.
/// * rho spans `[-diag, diag]`, so the accumulator now has `2 * diag + 1`
///   bins with an offset of `diag`. The old code sized it `diag` with an
///   offset of `diag / 2`, which funneled negative rho values into bin 0
///   (the saturating float-to-usize cast) and silently dropped every
///   rho greater than `diag / 2`.
fn detect_lines_hough(edges: &GrayImage, width: u32, height: u32) -> Result<BTreeMap<i32, f32>> {
    let diag =
        ((width as f64) * (width as f64) + (height as f64) * (height as f64)).sqrt() as usize;
    let rho_bins = 2 * diag + 1; // rho in [-diag, diag]
    let num_angles = 360;
    // Accumulator array for Hough space
    let mut accumulator = vec![vec![0u32; rho_bins]; num_angles];

    // Populate accumulator
    for y in 0..height {
        for x in 0..width {
            if edges.get_pixel(x, y)[0] > 128 {
                // Edge pixel: vote for every (theta, rho) line through it.
                for theta_idx in 0..num_angles {
                    let theta = (theta_idx as f32) * std::f32::consts::PI / 180.0;
                    let rho = (x as f32) * theta.cos() + (y as f32) * theta.sin();
                    let rho_idx = (rho + diag as f32) as i64;
                    if rho_idx >= 0 && (rho_idx as usize) < rho_bins {
                        accumulator[theta_idx][rho_idx as usize] += 1;
                    }
                }
            }
        }
    }

    // Find peaks in accumulator
    let mut angle_votes: BTreeMap<i32, f32> = BTreeMap::new();
    let threshold = (width.min(height) / 10) as u32; // Adaptive threshold

    for theta_idx in 0..num_angles {
        for rho_idx in 0..rho_bins {
            let votes = accumulator[theta_idx][rho_idx];
            if votes > threshold {
                let angle = (theta_idx as f32) - 180.0; // Convert to -180 to 180
                let normalized_angle = normalize_angle(angle);
                // Only consider angles near horizontal (within ±45°)
                if normalized_angle.abs() < 45.0 {
                    // Integer keys (angle * 10) keep tenth-degree precision
                    let key = (normalized_angle * 10.0) as i32;
                    *angle_votes.entry(key).or_insert(0.0) += votes as f32;
                }
            }
        }
    }
    Ok(angle_votes)
}
/// Normalize an angle (degrees) into the -45..=+45 range.
///
/// The angle is first folded modulo 180°, re-centered into -90..=90, and
/// finally clamped to ±45°.
fn normalize_angle(angle: f32) -> f32 {
    let folded = angle % 180.0;
    let centered = if folded > 90.0 {
        folded - 180.0
    } else if folded < -90.0 {
        folded + 180.0
    } else {
        folded
    };
    centered.clamp(-45.0, 45.0)
}
/// Deskew image using detected skew angle
///
/// Rotates the image about its center by the negated angle so a positive
/// (clockwise) skew is undone, filling uncovered corners with white.
/// Angles below 0.1° are treated as already straight and the image is
/// returned unchanged.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `angle` - Skew angle in degrees (from detect_skew_angle)
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::deskew::{detect_skew_angle, deskew_image};
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let angle = detect_skew_angle(&image).unwrap();
/// let deskewed = deskew_image(&image, angle).unwrap();
/// ```
pub fn deskew_image(image: &GrayImage, angle: f32) -> Result<GrayImage> {
    // Below this magnitude the correction is not worth the resampling.
    const MIN_CORRECTION_DEG: f32 = 0.1;
    if angle.abs() < MIN_CORRECTION_DEG {
        return Ok(image.clone());
    }

    // Positive detected skew is clockwise, so rotate the opposite way.
    let theta = (-angle).to_radians();
    Ok(rotate_about_center(
        image,
        theta,
        Interpolation::Bilinear,
        Luma([255]), // White background fill
    ))
}
/// Auto-deskew image with a plausibility threshold
///
/// Detects the skew angle and corrects it only when its magnitude does not
/// exceed `max_angle`; otherwise the image is returned unchanged with an
/// applied angle of `0.0`.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `max_angle` - Maximum angle to correct (degrees)
///
/// # Returns
/// Tuple of (deskewed_image, angle_applied)
pub fn auto_deskew(image: &GrayImage, max_angle: f32) -> Result<(GrayImage, f32)> {
    let detected = detect_skew_angle(image)?;
    if detected.abs() > max_angle {
        // Detected angle is implausibly large; leave the image untouched.
        return Ok((image.clone(), 0.0));
    }
    Ok((deskew_image(image, detected)?, detected))
}
/// Detect skew using the projection-profile method (alternative approach)
///
/// Evaluates a fixed set of candidate angles and returns the one whose
/// horizontal projection has the highest variance — text lines produce
/// sharp projection peaks when the rotation is correct. Faster but coarser
/// than the Hough-transform method.
pub fn detect_skew_projection(image: &GrayImage) -> Result<f32> {
    const CANDIDATES: [f32; 11] = [
        -45.0, -30.0, -15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0, 30.0, 45.0,
    ];

    // (best variance so far, corresponding angle); ties keep the earlier
    // candidate, and an all-zero profile yields 0.0.
    let mut best = (0.0_f32, 0.0_f32);
    for &candidate in CANDIDATES.iter() {
        let variance = calculate_projection_variance(image, candidate);
        if variance > best.0 {
            best = (variance, candidate);
        }
    }
    Ok(best.1)
}
/// Variance of the rotated horizontal projection profile at `angle` degrees.
///
/// Dark pixels (< 128) are binned by their rotated row coordinate; a high
/// variance means the profile has sharp peaks, i.e. the candidate angle
/// aligns text lines horizontally. Pixels projecting outside the bin range
/// are ignored.
fn calculate_projection_variance(image: &GrayImage, angle: f32) -> f32 {
    let (w, h) = image.dimensions();
    let theta = angle.to_radians();
    let (sin_t, cos_t) = theta.sin_cos();

    let mut bins = vec![0u32; h as usize];
    for row in 0..h {
        for col in 0..w {
            if image.get_pixel(col, row)[0] < 128 {
                let projected = ((row as f32) * cos_t - (col as f32) * sin_t) as i32;
                if projected >= 0 && projected < h as i32 {
                    bins[projected as usize] += 1;
                }
            }
        }
    }

    if bins.is_empty() {
        return 0.0;
    }
    let mean = bins.iter().sum::<u32>() as f32 / bins.len() as f32;
    bins.iter()
        .map(|&count| {
            let delta = count as f32 - mean;
            delta * delta
        })
        .sum::<f32>()
        / bins.len() as f32
}
#[cfg(test)]
mod tests {
    use super::*;

    // 200x100 white image with four horizontal black lines (fake text).
    fn create_test_image() -> GrayImage {
        let mut img = GrayImage::new(200, 100);
        // Fill with white
        for pixel in img.pixels_mut() {
            *pixel = Luma([255]);
        }
        // Draw some horizontal lines (simulating text)
        for y in [20, 40, 60, 80] {
            for x in 10..190 {
                img.put_pixel(x, y, Luma([0]));
            }
        }
        img
    }

    // Straight horizontal lines should yield near-zero detected skew.
    #[test]
    fn test_detect_skew_straight() {
        let img = create_test_image();
        let angle = detect_skew_angle(&img);
        assert!(angle.is_ok());
        let a = angle.unwrap();
        // Should detect near-zero skew for straight lines
        assert!(a.abs() < 10.0);
    }

    // Rotation preserves image dimensions.
    #[test]
    fn test_deskew_image() {
        let img = create_test_image();
        // Deskew by 5 degrees
        let deskewed = deskew_image(&img, 5.0);
        assert!(deskewed.is_ok());
        let result = deskewed.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    // Angles under the 0.1° threshold return the image unchanged.
    #[test]
    fn test_deskew_no_change() {
        let img = create_test_image();
        // Deskew by ~0 degrees
        let deskewed = deskew_image(&img, 0.05);
        assert!(deskewed.is_ok());
        let result = deskewed.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    // Auto-deskew never applies more than the allowed maximum angle.
    #[test]
    fn test_auto_deskew() {
        let img = create_test_image();
        let result = auto_deskew(&img, 15.0);
        assert!(result.is_ok());
        let (deskewed, angle) = result.unwrap();
        assert_eq!(deskewed.dimensions(), img.dimensions());
        assert!(angle.abs() <= 15.0);
    }

    // Folding + clamping behavior of the angle normalizer.
    #[test]
    fn test_normalize_angle() {
        assert!((normalize_angle(0.0) - 0.0).abs() < 0.01);
        // Test normalization behavior
        let angle_100 = normalize_angle(100.0);
        assert!(angle_100.abs() <= 45.0); // Should be clamped to ±45°
        let angle_neg100 = normalize_angle(-100.0);
        assert!(angle_neg100.abs() <= 45.0); // Should be clamped to ±45°
        assert!((normalize_angle(50.0) - 45.0).abs() < 0.01); // Clamped to 45
        assert!((normalize_angle(-50.0) - -45.0).abs() < 0.01); // Clamped to -45
    }

    // The coarse projection method stays within a loose error band.
    #[test]
    fn test_detect_skew_projection() {
        let img = create_test_image();
        let angle = detect_skew_projection(&img);
        assert!(angle.is_ok());
        let a = angle.unwrap();
        assert!(a.abs() < 20.0);
    }

    // Images under 20px per side are rejected.
    #[test]
    fn test_skew_small_image_error() {
        let small_img = GrayImage::new(10, 10);
        let result = detect_skew_angle(&small_img);
        assert!(result.is_err());
    }

    // Horizontal lines give a positive variance at 0°.
    #[test]
    fn test_projection_variance() {
        let img = create_test_image();
        let var_0 = calculate_projection_variance(&img, 0.0);
        let var_30 = calculate_projection_variance(&img, 30.0);
        // Variance at 0° should be higher for horizontal lines
        assert!(var_0 > 0.0);
        println!("Variance at 0°: {}, at 30°: {}", var_0, var_30);
    }
}

View File

@@ -0,0 +1,420 @@
//! Image enhancement functions for improving OCR accuracy
use super::{PreprocessError, Result};
use image::{GrayImage, Luma};
use std::cmp;
/// Contrast Limited Adaptive Histogram Equalization (CLAHE)
///
/// Improves local contrast while avoiding over-amplification of noise.
/// Divides image into tiles and applies histogram equalization with clipping.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `clip_limit` - Contrast clipping limit (typically 2.0-4.0)
/// * `tile_size` - Size of contextual regions (typically 8x8 or 16x16)
///
/// # Returns
/// Enhanced image with improved local contrast
///
/// # Errors
/// Returns `PreprocessError::InvalidParameters` when `tile_size` is zero or
/// `clip_limit` is not positive.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::enhancement::clahe;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let enhanced = clahe(&image, 2.0, 8).unwrap();
/// ```
pub fn clahe(image: &GrayImage, clip_limit: f32, tile_size: u32) -> Result<GrayImage> {
    if tile_size == 0 || clip_limit <= 0.0 {
        return Err(PreprocessError::InvalidParameters(
            "Invalid CLAHE parameters".to_string(),
        ));
    }
    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);
    // Number of tiles per axis, rounding up so edge pixels are covered.
    let tiles_x = (width + tile_size - 1) / tile_size;
    let tiles_y = (height + tile_size - 1) / tile_size;
    // Compute histograms and CDFs for each tile
    let mut tile_cdfs = vec![vec![Vec::new(); tiles_x as usize]; tiles_y as usize];
    for ty in 0..tiles_y {
        for tx in 0..tiles_x {
            let x_start = tx * tile_size;
            let y_start = ty * tile_size;
            // Edge tiles may be smaller than tile_size; clamp to the image.
            let x_end = cmp::min(x_start + tile_size, width);
            let y_end = cmp::min(y_start + tile_size, height);
            let cdf = compute_tile_cdf(image, x_start, y_start, x_end, y_end, clip_limit);
            tile_cdfs[ty as usize][tx as usize] = cdf;
        }
    }
    // Interpolate and apply transformation
    for y in 0..height {
        for x in 0..width {
            let pixel = image.get_pixel(x, y)[0];
            // Find tile coordinates
            let tx = (x as f32 / tile_size as f32).floor();
            let ty = (y as f32 / tile_size as f32).floor();
            // Calculate interpolation weights
            let x_ratio = (x as f32 / tile_size as f32) - tx;
            let y_ratio = (y as f32 / tile_size as f32) - ty;
            let tx = tx as usize;
            let ty = ty as usize;
            // Bilinear interpolation between neighboring tiles.
            // Pixels in the last tile row/column have no right/bottom
            // neighbour, so they fall back to linear interpolation along
            // the remaining axis, or to the tile's own CDF value.
            let value = if tx < tiles_x as usize - 1 && ty < tiles_y as usize - 1 {
                let v00 = tile_cdfs[ty][tx][pixel as usize];
                let v10 = tile_cdfs[ty][tx + 1][pixel as usize];
                let v01 = tile_cdfs[ty + 1][tx][pixel as usize];
                let v11 = tile_cdfs[ty + 1][tx + 1][pixel as usize];
                let v0 = v00 * (1.0 - x_ratio) + v10 * x_ratio;
                let v1 = v01 * (1.0 - x_ratio) + v11 * x_ratio;
                v0 * (1.0 - y_ratio) + v1 * y_ratio
            } else if tx < tiles_x as usize - 1 {
                let v0 = tile_cdfs[ty][tx][pixel as usize];
                let v1 = tile_cdfs[ty][tx + 1][pixel as usize];
                v0 * (1.0 - x_ratio) + v1 * x_ratio
            } else if ty < tiles_y as usize - 1 {
                let v0 = tile_cdfs[ty][tx][pixel as usize];
                let v1 = tile_cdfs[ty + 1][tx][pixel as usize];
                v0 * (1.0 - y_ratio) + v1 * y_ratio
            } else {
                tile_cdfs[ty][tx][pixel as usize]
            };
            // CDF values lie in [0, 1]; scale to the 8-bit intensity range.
            result.put_pixel(x, y, Luma([(value * 255.0) as u8]));
        }
    }
    Ok(result)
}
/// Build the contrast-limited CDF for one tile.
///
/// The tile's histogram is clipped at `clip_limit` (expressed as a multiple
/// of the uniform bin height) and the clipped mass is redistributed evenly
/// across all 256 bins before the cumulative distribution is computed.
/// An empty tile yields an all-zero CDF.
fn compute_tile_cdf(
    image: &GrayImage,
    x_start: u32,
    y_start: u32,
    x_end: u32,
    y_end: u32,
    clip_limit: f32,
) -> Vec<f32> {
    // Histogram of the tile's pixel intensities.
    let mut histogram = [0u32; 256];
    let mut pixel_count = 0u32;
    for y in y_start..y_end {
        for x in x_start..x_end {
            histogram[image.get_pixel(x, y)[0] as usize] += 1;
            pixel_count += 1;
        }
    }
    if pixel_count == 0 {
        return vec![0.0; 256];
    }
    // Clip bins that exceed the limit, tracking the removed mass.
    let bin_cap = (clip_limit * pixel_count as f32 / 256.0) as u32;
    let mut excess = 0u32;
    for bin in histogram.iter_mut() {
        if *bin > bin_cap {
            excess += *bin - bin_cap;
            *bin = bin_cap;
        }
    }
    // Spread the clipped mass evenly; the first `excess % 256` bins absorb
    // one extra count each so no mass is lost.
    let per_bin = excess / 256;
    let leftover = (excess % 256) as usize;
    for (i, bin) in histogram.iter_mut().enumerate() {
        *bin += per_bin + u32::from(i < leftover);
    }
    // Running sum of the clipped histogram, normalised to [0, 1].
    let mut running = 0u32;
    histogram
        .iter()
        .map(|&count| {
            running += count;
            running as f32 / pixel_count as f32
        })
        .collect()
}
/// Normalize brightness across the image
///
/// Shifts every pixel by a constant offset so the mean brightness becomes
/// 128, clamping results to the valid `0..=255` range.
///
/// # Arguments
/// * `image` - Input grayscale image
///
/// # Returns
/// Brightness-normalized image. A zero-sized image is returned unchanged
/// (there is no mean to normalize).
pub fn normalize_brightness(image: &GrayImage) -> GrayImage {
    let (width, height) = image.dimensions();
    // Count in u64: `width * height` as u32 overflows for 65536x65536.
    let pixel_count = (width as u64) * (height as u64);
    // Guard: an empty image has no mean; avoid a 0/0 -> NaN adjustment.
    if pixel_count == 0 {
        return image.clone();
    }
    // Sum in u64: a u32 accumulator overflows for images larger than
    // ~16.8 Mpixels (e.g. a fully white 4100x4100 scan).
    let sum: u64 = image.pixels().map(|p| p[0] as u64).sum();
    let mean = sum as f32 / pixel_count as f32;
    let target_mean = 128.0;
    let adjustment = target_mean - mean;
    // Apply the constant shift with saturation at the intensity bounds.
    let mut result = GrayImage::new(width, height);
    for (x, y, pixel) in image.enumerate_pixels() {
        let adjusted = (pixel[0] as f32 + adjustment).clamp(0.0, 255.0) as u8;
        result.put_pixel(x, y, Luma([adjusted]));
    }
    result
}
/// Remove shadows from document image
///
/// Estimates the illumination background with a large max filter and then
/// divides it out pixel-by-pixel, flattening slow gradients such as shadows.
///
/// # Arguments
/// * `image` - Input grayscale image
///
/// # Returns
/// Image with reduced shadows
pub fn remove_shadows(image: &GrayImage) -> Result<GrayImage> {
    let (width, height) = image.dimensions();
    // Kernel scales with the image but never drops below 15 pixels.
    let kernel_size = (width.min(height) / 20).max(15) as usize;
    let background = estimate_background(image, kernel_size);
    let mut result = GrayImage::new(width, height);
    for (x, y, pixel) in image.enumerate_pixels() {
        let bg = background.get_pixel(x, y)[0] as i32;
        let fg = pixel[0] as i32;
        // Divide out the background estimate: (fg / bg) * 255, saturated.
        // A zero background (only possible for an all-black window) passes
        // the foreground through untouched.
        let flattened = match bg {
            0 => fg as u8,
            _ => ((fg as f32 / bg as f32) * 255.0).min(255.0) as u8,
        };
        result.put_pixel(x, y, Luma([flattened]));
    }
    Ok(result)
}
/// Estimate the slowly-varying background via a max filter (grayscale dilation).
///
/// Each output pixel is the maximum intensity inside a square window of
/// `kernel_size` centred on it; window coordinates are clamped at the edges,
/// so border pixels reuse the nearest in-bounds samples.
fn estimate_background(image: &GrayImage, kernel_size: usize) -> GrayImage {
    let (width, height) = image.dimensions();
    let half = (kernel_size / 2) as i32;
    let mut background = GrayImage::new(width, height);
    for y in 0..height {
        for x in 0..width {
            let mut brightest = 0u8;
            for dy in -half..=half {
                let py = (y as i32 + dy).clamp(0, height as i32 - 1) as u32;
                for dx in -half..=half {
                    let px = (x as i32 + dx).clamp(0, width as i32 - 1) as u32;
                    brightest = brightest.max(image.get_pixel(px, py)[0]);
                }
            }
            background.put_pixel(x, y, Luma([brightest]));
        }
    }
    background
}
/// Enhance contrast using simple linear stretch
///
/// Linearly remaps the observed intensity range onto the full 0-255 span.
/// A uniform image (single intensity value) is returned unchanged to avoid
/// a division by zero.
pub fn contrast_stretch(image: &GrayImage) -> GrayImage {
    // Determine the observed intensity range in a single pass.
    let (min_val, max_val) = image
        .pixels()
        .fold((255u8, 0u8), |(lo, hi), p| (lo.min(p[0]), hi.max(p[0])));
    if min_val == max_val {
        return image.clone();
    }
    // Map [min_val, max_val] linearly onto [0, 255].
    let span = (max_val - min_val) as f32;
    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);
    for (x, y, pixel) in image.enumerate_pixels() {
        let mapped = ((pixel[0] - min_val) as f32 / span * 255.0) as u8;
        result.put_pixel(x, y, Luma([mapped]));
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;
    /// 100x100 diagonal-gradient image (values 0..=99) used as a generic
    /// low-contrast input for the enhancement tests below.
    fn create_test_image() -> GrayImage {
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = ((x + y) / 2) as u8;
                img.put_pixel(x, y, Luma([val]));
            }
        }
        img
    }
    // CLAHE should succeed and preserve dimensions on a valid input.
    #[test]
    fn test_clahe() {
        let img = create_test_image();
        let enhanced = clahe(&img, 2.0, 8);
        assert!(enhanced.is_ok());
        let result = enhanced.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }
    // Zero tile size and non-positive clip limit must both be rejected.
    #[test]
    fn test_clahe_invalid_params() {
        let img = create_test_image();
        // Invalid tile size
        let result = clahe(&img, 2.0, 0);
        assert!(result.is_err());
        // Invalid clip limit
        let result = clahe(&img, -1.0, 8);
        assert!(result.is_err());
    }
    // After normalization the mean brightness should sit near 128.
    #[test]
    fn test_normalize_brightness() {
        let img = create_test_image();
        let normalized = normalize_brightness(&img);
        assert_eq!(normalized.dimensions(), img.dimensions());
        // Check that mean is closer to 128
        let sum: u32 = normalized.pixels().map(|p| p[0] as u32).sum();
        let mean = sum as f32 / (100.0 * 100.0);
        assert!((mean - 128.0).abs() < 5.0);
    }
    // Shadow removal should succeed and preserve dimensions.
    #[test]
    fn test_remove_shadows() {
        let img = create_test_image();
        let result = remove_shadows(&img);
        assert!(result.is_ok());
        let shadow_removed = result.unwrap();
        assert_eq!(shadow_removed.dimensions(), img.dimensions());
    }
    // A narrow intensity range must be stretched to the full 0-255 span.
    #[test]
    fn test_contrast_stretch() {
        // Create low contrast image
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = 100 + ((x + y) / 10) as u8; // Range: 100-119
                img.put_pixel(x, y, Luma([val]));
            }
        }
        let stretched = contrast_stretch(&img);
        // Check that range is now 0-255
        let mut min_val = 255u8;
        let mut max_val = 0u8;
        for pixel in stretched.pixels() {
            let val = pixel[0];
            if val < min_val {
                min_val = val;
            }
            if val > max_val {
                max_val = val;
            }
        }
        assert_eq!(min_val, 0);
        assert_eq!(max_val, 255);
    }
    // A uniform image has no range to stretch and must pass through unchanged.
    #[test]
    fn test_contrast_stretch_uniform() {
        // Uniform image should remain unchanged
        let mut img = GrayImage::new(50, 50);
        for pixel in img.pixels_mut() {
            *pixel = Luma([128]);
        }
        let stretched = contrast_stretch(&img);
        for pixel in stretched.pixels() {
            assert_eq!(pixel[0], 128);
        }
    }
    // A max filter can only raise (or keep) each pixel's value.
    #[test]
    fn test_estimate_background() {
        let img = create_test_image();
        let background = estimate_background(&img, 5);
        assert_eq!(background.dimensions(), img.dimensions());
        // Background should have higher values (max filter)
        for (orig, bg) in img.pixels().zip(background.pixels()) {
            assert!(bg[0] >= orig[0]);
        }
    }
    // CLAHE must cope with several tile sizes, including non-divisors of 100.
    #[test]
    fn test_clahe_various_tile_sizes() {
        let img = create_test_image();
        for tile_size in [4, 8, 16, 32] {
            let result = clahe(&img, 2.0, tile_size);
            assert!(result.is_ok());
        }
    }
}

View File

@@ -0,0 +1,277 @@
//! Image preprocessing module for OCR pipeline
//!
//! This module provides comprehensive image preprocessing capabilities including:
//! - Image transformations (grayscale, blur, sharpen, threshold)
//! - Rotation detection and correction
//! - Skew correction (deskewing)
//! - Image enhancement (CLAHE, normalization)
//! - Text region segmentation
//! - Complete preprocessing pipeline with parallel processing
pub mod deskew;
pub mod enhancement;
pub mod pipeline;
pub mod rotation;
pub mod segmentation;
pub mod transforms;
use image::{DynamicImage, GrayImage};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Preprocessing error types
///
/// Unified error type for the whole `preprocess` module; the `thiserror`
/// derive supplies the `Display` and `std::error::Error` implementations.
#[derive(Error, Debug)]
pub enum PreprocessError {
    /// Failure while loading or decoding the input image.
    #[error("Image loading error: {0}")]
    ImageLoad(String),
    /// A caller supplied an out-of-range or inconsistent parameter.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),
    /// A processing step failed mid-pipeline.
    #[error("Processing error: {0}")]
    Processing(String),
    /// Text-region segmentation failed.
    #[error("Segmentation error: {0}")]
    Segmentation(String),
}
/// Result type for preprocessing operations
///
/// Module-wide shorthand for `Result<T, PreprocessError>`.
pub type Result<T> = std::result::Result<T, PreprocessError>;
/// Preprocessing options for configuring the pipeline
///
/// Serializable so a configuration can be stored or sent over an API.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessOptions {
    /// Enable rotation detection and correction
    pub auto_rotate: bool,
    /// Enable skew detection and correction
    pub auto_deskew: bool,
    /// Enable contrast enhancement
    pub enhance_contrast: bool,
    /// Enable denoising
    pub denoise: bool,
    /// Binarization threshold (None for auto Otsu); ignored while
    /// `adaptive_threshold` is enabled, which takes precedence
    pub threshold: Option<u8>,
    /// Enable adaptive thresholding
    pub adaptive_threshold: bool,
    /// Adaptive threshold window size, in pixels
    pub adaptive_window_size: u32,
    /// Target image width (None to keep original)
    pub target_width: Option<u32>,
    /// Target image height (None to keep original); resizing only happens
    /// when both target dimensions are set
    pub target_height: Option<u32>,
    /// Enable text region detection
    pub detect_regions: bool,
    /// Gaussian blur sigma for denoising
    pub blur_sigma: f32,
    /// CLAHE clip limit for contrast enhancement (typically 2.0-4.0)
    pub clahe_clip_limit: f32,
    /// CLAHE tile size, in pixels (typically 8 or 16)
    pub clahe_tile_size: u32,
}
impl Default for PreprocessOptions {
fn default() -> Self {
Self {
auto_rotate: true,
auto_deskew: true,
enhance_contrast: true,
denoise: true,
threshold: None,
adaptive_threshold: true,
adaptive_window_size: 15,
target_width: None,
target_height: None,
detect_regions: true,
blur_sigma: 1.0,
clahe_clip_limit: 2.0,
clahe_tile_size: 8,
}
}
}
/// Type of text region
///
/// Used by [`TextRegion`] to label the kind of content detected
/// inside a bounding box.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
    /// Regular text
    Text,
    /// Mathematical equation
    Math,
    /// Table
    Table,
    /// Figure/Image
    Figure,
    /// Unknown/Other
    Unknown,
}
/// Detected text region with bounding box
///
/// Produced by [`detect_text_regions`]; serializable for storage or APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
    /// Region type
    pub region_type: RegionType,
    /// Bounding box (x, y, width, height) in pixels
    pub bbox: (u32, u32, u32, u32),
    /// Confidence score (0.0 to 1.0)
    pub confidence: f32,
    /// Average text height in pixels
    pub text_height: f32,
    /// Detected baseline angle in degrees
    pub baseline_angle: f32,
}
/// Main preprocessing function with configurable options
///
/// Convenience wrapper: translates a [`PreprocessOptions`] value into a
/// configured [`pipeline::PreprocessPipeline`] and runs it on `image`.
///
/// # Arguments
/// * `image` - Input image to preprocess
/// * `options` - Preprocessing configuration options
///
/// # Returns
/// Preprocessed grayscale image ready for OCR
///
/// # Example
/// ```no_run
/// use image::open;
/// use ruvector_scipix::preprocess::{preprocess, PreprocessOptions};
///
/// let img = open("document.jpg").unwrap();
/// let options = PreprocessOptions::default();
/// let processed = preprocess(&img, &options).unwrap();
/// ```
pub fn preprocess(image: &DynamicImage, options: &PreprocessOptions) -> Result<GrayImage> {
    // Mirror every relevant option onto the pipeline builder, then run it.
    let pipeline = pipeline::PreprocessPipeline::builder()
        .auto_rotate(options.auto_rotate)
        .auto_deskew(options.auto_deskew)
        .enhance_contrast(options.enhance_contrast)
        .denoise(options.denoise)
        .blur_sigma(options.blur_sigma)
        .clahe_clip_limit(options.clahe_clip_limit)
        .clahe_tile_size(options.clahe_tile_size)
        .threshold(options.threshold)
        .adaptive_threshold(options.adaptive_threshold)
        .adaptive_window_size(options.adaptive_window_size)
        .target_size(options.target_width, options.target_height)
        .build();
    pipeline.process(image)
}
/// Detect text regions in an image
///
/// Thin wrapper around [`segmentation::find_text_regions`], which performs
/// the actual detection.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `min_region_size` - Minimum region size in pixels
///
/// # Returns
/// Vector of detected text regions with metadata
///
/// # Example
/// ```no_run
/// use image::open;
/// use ruvector_scipix::preprocess::detect_text_regions;
///
/// let img = open("document.jpg").unwrap().to_luma8();
/// let regions = detect_text_regions(&img, 100).unwrap();
/// println!("Found {} text regions", regions.len());
/// ```
pub fn detect_text_regions(image: &GrayImage, min_region_size: u32) -> Result<Vec<TextRegion>> {
    let regions = segmentation::find_text_regions(image, min_region_size)?;
    Ok(regions)
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};
    /// Gray diagonal-gradient RGB image of the requested size.
    fn create_test_image(width: u32, height: u32) -> DynamicImage {
        let mut img = RgbImage::new(width, height);
        // Create a simple test pattern
        for y in 0..height {
            for x in 0..width {
                let val = ((x + y) % 256) as u8;
                img.put_pixel(x, y, Rgb([val, val, val]));
            }
        }
        DynamicImage::ImageRgb8(img)
    }
    // Default options should process successfully without resizing.
    #[test]
    fn test_preprocess_default_options() {
        let img = create_test_image(100, 100);
        let options = PreprocessOptions::default();
        let result = preprocess(&img, &options);
        assert!(result.is_ok());
        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }
    // Target dimensions must be honoured when both are set.
    #[test]
    fn test_preprocess_with_resize() {
        let img = create_test_image(200, 200);
        let mut options = PreprocessOptions::default();
        options.target_width = Some(100);
        options.target_height = Some(100);
        let result = preprocess(&img, &options);
        assert!(result.is_ok());
        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }
    // Struct-update syntax with Default must preserve overridden fields.
    #[test]
    fn test_preprocess_options_builder() {
        let options = PreprocessOptions {
            auto_rotate: false,
            auto_deskew: false,
            enhance_contrast: true,
            denoise: true,
            threshold: Some(128),
            adaptive_threshold: false,
            ..Default::default()
        };
        assert!(!options.auto_rotate);
        assert!(!options.auto_deskew);
        assert!(options.enhance_contrast);
        assert_eq!(options.threshold, Some(128));
    }
    // TextRegion must round-trip through serde_json without loss.
    #[test]
    fn test_region_type_serialization() {
        let region = TextRegion {
            region_type: RegionType::Math,
            bbox: (10, 20, 100, 50),
            confidence: 0.95,
            text_height: 12.0,
            baseline_angle: 0.5,
        };
        let json = serde_json::to_string(&region).unwrap();
        let deserialized: TextRegion = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.region_type, RegionType::Math);
        assert_eq!(deserialized.bbox, (10, 20, 100, 50));
        assert!((deserialized.confidence - 0.95).abs() < 0.001);
    }
}

View File

@@ -0,0 +1,456 @@
//! Complete preprocessing pipeline with builder pattern and parallel processing
use super::Result;
use crate::preprocess::{deskew, enhancement, rotation, transforms};
use image::{DynamicImage, GrayImage};
use rayon::prelude::*;
use std::sync::Arc;
/// Progress callback type
///
/// Invoked with a human-readable step name and a completion fraction in
/// `[0.0, 1.0]`. Must be `Send + Sync` because batch processing runs on
/// rayon worker threads.
pub type ProgressCallback = Arc<dyn Fn(&str, f32) + Send + Sync>;
/// Complete preprocessing pipeline with configurable steps
///
/// Construct via [`PreprocessPipeline::builder`]; the configuration is
/// fixed once built.
pub struct PreprocessPipeline {
    // Geometric corrections
    auto_rotate: bool,
    auto_deskew: bool,
    // Enhancement and denoising
    enhance_contrast: bool,
    denoise: bool,
    blur_sigma: f32,
    clahe_clip_limit: f32,
    clahe_tile_size: u32,
    // Binarization (adaptive takes precedence over a fixed threshold;
    // with neither, automatic Otsu is used)
    threshold: Option<u8>,
    adaptive_threshold: bool,
    adaptive_window_size: u32,
    // Optional output resize, applied only when BOTH dimensions are Some
    target_width: Option<u32>,
    target_height: Option<u32>,
    // Optional progress reporting hook
    progress_callback: Option<ProgressCallback>,
}
/// Builder for preprocessing pipeline
///
/// Mirrors the configuration fields of [`PreprocessPipeline`]; each field
/// has a corresponding chainable setter.
pub struct PreprocessPipelineBuilder {
    auto_rotate: bool,
    auto_deskew: bool,
    enhance_contrast: bool,
    denoise: bool,
    blur_sigma: f32,
    clahe_clip_limit: f32,
    clahe_tile_size: u32,
    threshold: Option<u8>,
    adaptive_threshold: bool,
    adaptive_window_size: u32,
    target_width: Option<u32>,
    target_height: Option<u32>,
    progress_callback: Option<ProgressCallback>,
}
impl Default for PreprocessPipelineBuilder {
fn default() -> Self {
Self {
auto_rotate: true,
auto_deskew: true,
enhance_contrast: true,
denoise: true,
blur_sigma: 1.0,
clahe_clip_limit: 2.0,
clahe_tile_size: 8,
threshold: None,
adaptive_threshold: true,
adaptive_window_size: 15,
target_width: None,
target_height: None,
progress_callback: None,
}
}
}
impl PreprocessPipelineBuilder {
    /// Start from the default configuration.
    pub fn new() -> Self {
        Self::default()
    }
    /// Toggle automatic rotation detection and correction.
    pub fn auto_rotate(self, enable: bool) -> Self {
        Self {
            auto_rotate: enable,
            ..self
        }
    }
    /// Toggle automatic skew detection and correction.
    pub fn auto_deskew(self, enable: bool) -> Self {
        Self {
            auto_deskew: enable,
            ..self
        }
    }
    /// Toggle CLAHE contrast enhancement.
    pub fn enhance_contrast(self, enable: bool) -> Self {
        Self {
            enhance_contrast: enable,
            ..self
        }
    }
    /// Toggle Gaussian-blur denoising.
    pub fn denoise(self, enable: bool) -> Self {
        Self {
            denoise: enable,
            ..self
        }
    }
    /// Set the Gaussian blur sigma used when denoising is enabled.
    pub fn blur_sigma(self, sigma: f32) -> Self {
        Self {
            blur_sigma: sigma,
            ..self
        }
    }
    /// Set the CLAHE contrast clip limit.
    pub fn clahe_clip_limit(self, limit: f32) -> Self {
        Self {
            clahe_clip_limit: limit,
            ..self
        }
    }
    /// Set the CLAHE tile size.
    pub fn clahe_tile_size(self, size: u32) -> Self {
        Self {
            clahe_tile_size: size,
            ..self
        }
    }
    /// Set a fixed binarization threshold (`None` selects automatic Otsu).
    pub fn threshold(self, threshold: Option<u8>) -> Self {
        Self { threshold, ..self }
    }
    /// Toggle adaptive thresholding (takes precedence over a fixed threshold).
    pub fn adaptive_threshold(self, enable: bool) -> Self {
        Self {
            adaptive_threshold: enable,
            ..self
        }
    }
    /// Set the adaptive-threshold window size.
    pub fn adaptive_window_size(self, size: u32) -> Self {
        Self {
            adaptive_window_size: size,
            ..self
        }
    }
    /// Set the output size; resizing happens only when both are `Some`.
    pub fn target_size(self, width: Option<u32>, height: Option<u32>) -> Self {
        Self {
            target_width: width,
            target_height: height,
            ..self
        }
    }
    /// Register a progress callback invoked as `(step_name, fraction)`.
    pub fn progress_callback<F>(self, callback: F) -> Self
    where
        F: Fn(&str, f32) + Send + Sync + 'static,
    {
        Self {
            progress_callback: Some(Arc::new(callback)),
            ..self
        }
    }
    /// Consume the builder and produce an immutable pipeline.
    pub fn build(self) -> PreprocessPipeline {
        let Self {
            auto_rotate,
            auto_deskew,
            enhance_contrast,
            denoise,
            blur_sigma,
            clahe_clip_limit,
            clahe_tile_size,
            threshold,
            adaptive_threshold,
            adaptive_window_size,
            target_width,
            target_height,
            progress_callback,
        } = self;
        PreprocessPipeline {
            auto_rotate,
            auto_deskew,
            enhance_contrast,
            denoise,
            blur_sigma,
            clahe_clip_limit,
            clahe_tile_size,
            threshold,
            adaptive_threshold,
            adaptive_window_size,
            target_width,
            target_height,
            progress_callback,
        }
    }
}
impl PreprocessPipeline {
    /// Create a new pipeline builder
    pub fn builder() -> PreprocessPipelineBuilder {
        PreprocessPipelineBuilder::new()
    }
    /// Report progress if callback is set
    ///
    /// No-op when no progress callback was registered on the builder.
    fn report_progress(&self, step: &str, progress: f32) {
        if let Some(callback) = &self.progress_callback {
            callback(step, progress);
        }
    }
    /// Process a single image through the complete pipeline
    ///
    /// # Pipeline steps:
    /// 1. Convert to grayscale
    /// 2. Detect and correct rotation (if enabled)
    /// 3. Detect and correct skew (if enabled)
    /// 4. Enhance contrast with CLAHE (if enabled)
    /// 5. Denoise with Gaussian blur (if enabled)
    /// 6. Apply thresholding (binary or adaptive)
    /// 7. Resize to target dimensions (if specified)
    ///
    /// Thresholding precedence: adaptive (when enabled) beats a fixed
    /// `threshold`, which in turn beats automatic Otsu.
    pub fn process(&self, image: &DynamicImage) -> Result<GrayImage> {
        self.report_progress("Starting preprocessing", 0.0);
        // Step 1: Convert to grayscale
        self.report_progress("Converting to grayscale", 0.1);
        let mut gray = transforms::to_grayscale(image);
        // Step 2: Auto-rotate
        if self.auto_rotate {
            self.report_progress("Detecting rotation", 0.2);
            let angle = rotation::detect_rotation(&gray)?;
            // Sub-half-degree angles are ignored to avoid needless resampling.
            if angle.abs() > 0.5 {
                self.report_progress("Correcting rotation", 0.25);
                gray = rotation::rotate_image(&gray, -angle)?;
            }
        }
        // Step 3: Auto-deskew
        if self.auto_deskew {
            self.report_progress("Detecting skew", 0.3);
            let angle = deskew::detect_skew_angle(&gray)?;
            if angle.abs() > 0.5 {
                self.report_progress("Correcting skew", 0.35);
                gray = deskew::deskew_image(&gray, angle)?;
            }
        }
        // Step 4: Enhance contrast
        if self.enhance_contrast {
            self.report_progress("Enhancing contrast", 0.5);
            gray = enhancement::clahe(&gray, self.clahe_clip_limit, self.clahe_tile_size)?;
        }
        // Step 5: Denoise
        if self.denoise {
            self.report_progress("Denoising", 0.6);
            gray = transforms::gaussian_blur(&gray, self.blur_sigma)?;
        }
        // Step 6: Thresholding
        // Precedence: adaptive > fixed threshold > automatic Otsu.
        self.report_progress("Applying threshold", 0.7);
        gray = if self.adaptive_threshold {
            transforms::adaptive_threshold(&gray, self.adaptive_window_size)?
        } else if let Some(threshold_val) = self.threshold {
            transforms::threshold(&gray, threshold_val)
        } else {
            // Auto Otsu threshold
            let threshold_val = transforms::otsu_threshold(&gray)?;
            transforms::threshold(&gray, threshold_val)
        };
        // Step 7: Resize
        // Applied only when BOTH target dimensions are given.
        if let (Some(width), Some(height)) = (self.target_width, self.target_height) {
            self.report_progress("Resizing", 0.9);
            gray = image::imageops::resize(
                &gray,
                width,
                height,
                image::imageops::FilterType::Lanczos3,
            );
        }
        self.report_progress("Preprocessing complete", 1.0);
        Ok(gray)
    }
    /// Process multiple images in parallel
    ///
    /// Images run on the rayon thread pool; if any image fails, the whole
    /// batch returns an error.
    ///
    /// # Arguments
    /// * `images` - Vector of images to process
    ///
    /// # Returns
    /// Vector of preprocessed images in the same order
    pub fn process_batch(&self, images: Vec<DynamicImage>) -> Result<Vec<GrayImage>> {
        images
            .into_par_iter()
            .map(|img| self.process(&img))
            .collect()
    }
    /// Process image and return intermediate results from each step
    ///
    /// Useful for debugging and visualization. Steps that are disabled or
    /// skipped (e.g. a detected rotation below the 0.5-degree cut-off)
    /// produce no snapshot, so the returned vector length varies with the
    /// configuration.
    pub fn process_with_intermediates(
        &self,
        image: &DynamicImage,
    ) -> Result<Vec<(String, GrayImage)>> {
        let mut results = Vec::new();
        // Step 1: Grayscale
        let mut gray = transforms::to_grayscale(image);
        results.push(("01_grayscale".to_string(), gray.clone()));
        // Step 2: Rotation
        if self.auto_rotate {
            let angle = rotation::detect_rotation(&gray)?;
            if angle.abs() > 0.5 {
                gray = rotation::rotate_image(&gray, -angle)?;
                results.push(("02_rotated".to_string(), gray.clone()));
            }
        }
        // Step 3: Deskew
        if self.auto_deskew {
            let angle = deskew::detect_skew_angle(&gray)?;
            if angle.abs() > 0.5 {
                gray = deskew::deskew_image(&gray, angle)?;
                results.push(("03_deskewed".to_string(), gray.clone()));
            }
        }
        // Step 4: Enhancement
        if self.enhance_contrast {
            gray = enhancement::clahe(&gray, self.clahe_clip_limit, self.clahe_tile_size)?;
            results.push(("04_enhanced".to_string(), gray.clone()));
        }
        // Step 5: Denoise
        if self.denoise {
            gray = transforms::gaussian_blur(&gray, self.blur_sigma)?;
            results.push(("05_denoised".to_string(), gray.clone()));
        }
        // Step 6: Threshold (same precedence as `process`)
        gray = if self.adaptive_threshold {
            transforms::adaptive_threshold(&gray, self.adaptive_window_size)?
        } else if let Some(threshold_val) = self.threshold {
            transforms::threshold(&gray, threshold_val)
        } else {
            let threshold_val = transforms::otsu_threshold(&gray)?;
            transforms::threshold(&gray, threshold_val)
        };
        results.push(("06_thresholded".to_string(), gray.clone()));
        // Step 7: Resize
        if let (Some(width), Some(height)) = (self.target_width, self.target_height) {
            gray = image::imageops::resize(
                &gray,
                width,
                height,
                image::imageops::FilterType::Lanczos3,
            );
            results.push(("07_resized".to_string(), gray.clone()));
        }
        Ok(results)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};
    /// 100x100 gray diagonal-gradient image shared by the pipeline tests.
    fn create_test_image() -> DynamicImage {
        let mut img = RgbImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = ((x + y) / 2) as u8;
                img.put_pixel(x, y, Rgb([val, val, val]));
            }
        }
        DynamicImage::ImageRgb8(img)
    }
    // Builder setters must be reflected on the built pipeline.
    #[test]
    fn test_pipeline_builder() {
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .denoise(true)
            .blur_sigma(1.5)
            .build();
        assert!(!pipeline.auto_rotate);
        assert!(pipeline.denoise);
        assert!((pipeline.blur_sigma - 1.5).abs() < 0.001);
    }
    // An end-to-end run without geometric corrections keeps dimensions.
    #[test]
    fn test_pipeline_process() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .build();
        let result = pipeline.process(&img);
        assert!(result.is_ok());
        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }
    // Resizing to a target size must produce exactly that size.
    #[test]
    fn test_pipeline_with_resize() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .target_size(Some(50), Some(50))
            .auto_rotate(false)
            .auto_deskew(false)
            .build();
        let result = pipeline.process(&img);
        assert!(result.is_ok());
        let processed = result.unwrap();
        assert_eq!(processed.width(), 50);
        assert_eq!(processed.height(), 50);
    }
    // Parallel batch processing must return one output per input.
    #[test]
    fn test_pipeline_batch_processing() {
        let images = vec![
            create_test_image(),
            create_test_image(),
            create_test_image(),
        ];
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .build();
        let results = pipeline.process_batch(images);
        assert!(results.is_ok());
        let processed = results.unwrap();
        assert_eq!(processed.len(), 3);
    }
    // Intermediate snapshots must include at least grayscale and threshold.
    #[test]
    fn test_pipeline_intermediates() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .enhance_contrast(true)
            .denoise(true)
            .build();
        let result = pipeline.process_with_intermediates(&img);
        assert!(result.is_ok());
        let intermediates = result.unwrap();
        assert!(!intermediates.is_empty());
        assert!(intermediates
            .iter()
            .any(|(name, _)| name.contains("grayscale")));
        assert!(intermediates
            .iter()
            .any(|(name, _)| name.contains("thresholded")));
    }
    // The progress callback must fire, including the start and end events.
    #[test]
    fn test_progress_callback() {
        use std::sync::{Arc, Mutex};
        let progress_steps = Arc::new(Mutex::new(Vec::new()));
        let progress_clone = Arc::clone(&progress_steps);
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .progress_callback(move |step, _progress| {
                progress_clone.lock().unwrap().push(step.to_string());
            })
            .build();
        let img = create_test_image();
        let _ = pipeline.process(&img);
        let steps = progress_steps.lock().unwrap();
        assert!(!steps.is_empty());
        assert!(steps.iter().any(|s| s.contains("Starting")));
        assert!(steps.iter().any(|s| s.contains("complete")));
    }
}

View File

@@ -0,0 +1,319 @@
//! Rotation detection and correction using projection profiles
use super::{PreprocessError, Result};
use image::{GrayImage, Luma};
use imageproc::geometric_transformations::{rotate_about_center, Interpolation};
use std::f32;
/// Detect rotation angle using projection profile analysis
///
/// Runs a coarse scan over widely spaced candidate angles, then a fine scan
/// (2-degree steps, +/-10 degrees) centred on the best coarse match. The
/// candidate whose projection profile has the highest variance wins.
///
/// # Arguments
/// * `image` - Input grayscale image
///
/// # Returns
/// Rotation angle in degrees (positive = clockwise), typically in -45..=45
///
/// # Errors
/// Returns `PreprocessError::InvalidParameters` for images smaller than
/// 10x10 pixels.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::rotation::detect_rotation;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let angle = detect_rotation(&image).unwrap();
/// println!("Detected rotation: {:.2}°", angle);
/// ```
pub fn detect_rotation(image: &GrayImage) -> Result<f32> {
    let (width, height) = image.dimensions();
    if width < 10 || height < 10 {
        return Err(PreprocessError::InvalidParameters(
            "Image too small for rotation detection".to_string(),
        ));
    }
    // Coarse pass over a handful of widely spaced candidates.
    const COARSE: [f32; 7] = [-45.0, -30.0, -15.0, 0.0, 15.0, 30.0, 45.0];
    let mut best_angle = 0.0f32;
    let mut best_score = 0.0f32;
    for &candidate in COARSE.iter() {
        let score = calculate_projection_score(image, candidate);
        if score > best_score {
            best_score = score;
            best_angle = candidate;
        }
    }
    // Fine pass: re-scan +/-10 degrees around the coarse winner in 2-degree
    // steps. The score is reset so the fine pass stands on its own.
    let centre = best_angle;
    best_score = 0.0;
    for step in -5i32..=5 {
        let candidate = centre + step as f32 * 2.0;
        let score = calculate_projection_score(image, candidate);
        if score > best_score {
            best_score = score;
            best_angle = candidate;
        }
    }
    Ok(best_angle)
}
/// Calculate projection profile score for a given rotation angle
///
/// Higher scores indicate better alignment with text baselines.
fn calculate_projection_score(image: &GrayImage, angle: f32) -> f32 {
    // Near-zero angles take the cheap axis-aligned path.
    if angle.abs() < 0.1 {
        return calculate_horizontal_projection_variance(image);
    }
    let (width, height) = image.dimensions();
    let (sin_a, cos_a) = angle.to_radians().sin_cos();
    // Histogram of dark pixels (value < 128) projected onto the rotated
    // horizontal axis; out-of-range projections are dropped.
    let mut profile = vec![0u32; height as usize];
    for y in 0..height {
        for x in 0..width {
            if image.get_pixel(x, y)[0] < 128 {
                let bucket = ((y as f32) * cos_a - (x as f32) * sin_a) as i32;
                if (0..height as i32).contains(&bucket) {
                    profile[bucket as usize] += 1;
                }
            }
        }
    }
    // Higher variance means the projection concentrates into fewer rows,
    // i.e. the candidate angle matches the text orientation better.
    calculate_variance(&profile)
}
/// Variance of the per-row count of dark pixels (threshold < 128), used as
/// the projection score for the unrotated (0°) case.
fn calculate_horizontal_projection_variance(image: &GrayImage) -> f32 {
    let mut rows = vec![0u32; image.height() as usize];
    for (_x, y, pixel) in image.enumerate_pixels() {
        if pixel[0] < 128 {
            rows[y as usize] += 1;
        }
    }
    calculate_variance(&rows)
}
/// Population variance of a projection profile; 0.0 for an empty slice.
fn calculate_variance(projection: &[u32]) -> f32 {
    let n = projection.len();
    if n == 0 {
        return 0.0;
    }
    let mean = projection.iter().sum::<u32>() as f32 / n as f32;
    let sum_sq: f32 = projection
        .iter()
        .map(|&v| (v as f32 - mean).powi(2))
        .sum();
    sum_sq / n as f32
}
/// Rotate `image` by `angle` degrees (positive = clockwise), resampling with
/// bilinear interpolation and filling exposed corners with white.
///
/// Angles below 0.01° in magnitude are treated as "no rotation" and return a
/// plain copy of the input.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::rotation::rotate_image;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let rotated = rotate_image(&image, 15.0).unwrap();
/// ```
pub fn rotate_image(image: &GrayImage, angle: f32) -> Result<GrayImage> {
    if angle.abs() < 0.01 {
        // No-op fast path: nothing to resample.
        return Ok(image.clone());
    }
    // Negated so a positive `angle` matches this module's clockwise-positive
    // convention (the underlying imageproc rotation uses the opposite sign).
    let theta = -angle.to_radians();
    Ok(rotate_about_center(
        image,
        theta,
        Interpolation::Bilinear,
        Luma([255]), // White background for the uncovered corners.
    ))
}
/// Detect rotation and report a confidence score in `[0.0, 1.0]`.
///
/// Confidence is the ratio of the projection score at the detected angle to
/// the score at 0°, capped at 1.0. When the 0° baseline score is zero (no
/// usable signal) a neutral 0.5 is reported instead.
pub fn detect_rotation_with_confidence(image: &GrayImage) -> Result<(f32, f32)> {
    let detected = detect_rotation(image)?;
    let score_at_detected = calculate_projection_score(image, detected);
    let score_at_zero = calculate_projection_score(image, 0.0);
    let confidence = if score_at_zero > 0.0 {
        (score_at_detected / score_at_zero).min(1.0)
    } else {
        0.5 // No usable baseline; report moderate confidence.
    };
    Ok((detected, confidence))
}
/// Deskew `image` only when the detector is confident enough.
///
/// Rotation is applied only when the detection confidence reaches
/// `confidence_threshold` AND the detected angle exceeds 0.5° in magnitude;
/// otherwise the input is returned unchanged with an applied angle of 0.0.
///
/// # Returns
/// `(possibly_rotated_image, detected_angle_or_zero, confidence)`
pub fn auto_rotate(image: &GrayImage, confidence_threshold: f32) -> Result<(GrayImage, f32, f32)> {
    let (angle, confidence) = detect_rotation_with_confidence(image)?;
    let should_rotate = confidence >= confidence_threshold && angle.abs() > 0.5;
    if !should_rotate {
        return Ok((image.clone(), 0.0, confidence));
    }
    // Rotate by the negated angle to undo the detected skew.
    let corrected = rotate_image(image, -angle)?;
    Ok((corrected, angle, confidence))
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builds a white 200x100 canvas with four solid black rows (y = 20, 25,
    // 50, 55) standing in for horizontal text baselines.
    fn create_text_image() -> GrayImage {
        let mut img = GrayImage::new(200, 100);
        // Fill with white
        for pixel in img.pixels_mut() {
            *pixel = Luma([255]);
        }
        // Draw some horizontal lines (simulating text)
        for y in [20, 25, 50, 55] {
            for x in 10..190 {
                img.put_pixel(x, y, Luma([0]));
            }
        }
        img
    }
    // An unrotated synthetic page should be detected as near-horizontal.
    #[test]
    fn test_detect_rotation_straight() {
        let img = create_text_image();
        let angle = detect_rotation(&img);
        assert!(angle.is_ok());
        let a = angle.unwrap();
        // Should detect near-zero rotation
        assert!(a.abs() < 10.0);
    }
    // Rotation keeps the output canvas the same size as the input.
    #[test]
    fn test_rotate_image() {
        let img = create_text_image();
        // Rotate by 15 degrees
        let rotated = rotate_image(&img, 15.0);
        assert!(rotated.is_ok());
        let result = rotated.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }
    // Angles below the 0.01 degree cutoff take the no-op clone path.
    #[test]
    fn test_rotate_no_change() {
        let img = create_text_image();
        // Rotate by ~0 degrees
        let rotated = rotate_image(&img, 0.001);
        assert!(rotated.is_ok());
        let result = rotated.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }
    // Confidence must always land in [0, 1].
    #[test]
    fn test_rotation_confidence() {
        let img = create_text_image();
        let result = detect_rotation_with_confidence(&img);
        assert!(result.is_ok());
        let (angle, confidence) = result.unwrap();
        assert!(confidence >= 0.0 && confidence <= 1.0);
        println!(
            "Detected angle: {:.2}°, confidence: {:.2}",
            angle, confidence
        );
    }
    // With a very high threshold the image may pass through unrotated, but
    // dimensions are preserved either way.
    #[test]
    fn test_auto_rotate_with_threshold() {
        let img = create_text_image();
        // High threshold - should not rotate if confidence is low
        let result = auto_rotate(&img, 0.95);
        assert!(result.is_ok());
        let (rotated, angle, confidence) = result.unwrap();
        assert_eq!(rotated.dimensions(), img.dimensions());
        println!(
            "Auto-rotate: angle={:.2}°, confidence={:.2}",
            angle, confidence
        );
    }
    // A non-uniform profile has strictly positive variance.
    #[test]
    fn test_projection_variance() {
        let projection = vec![10, 50, 100, 50, 10];
        let variance = calculate_variance(&projection);
        assert!(variance > 0.0);
    }
    // Images under 10px per side are rejected as too small.
    #[test]
    fn test_rotation_small_image_error() {
        let small_img = GrayImage::new(5, 5);
        let result = detect_rotation(&small_img);
        assert!(result.is_err());
    }
    // Rotating and un-rotating must at least preserve dimensions (pixel
    // content is lossy due to interpolation, so only size is checked).
    #[test]
    fn test_rotation_roundtrip() {
        let img = create_text_image();
        // Rotate and unrotate
        let rotated = rotate_image(&img, 30.0).unwrap();
        let unrotated = rotate_image(&rotated, -30.0).unwrap();
        assert_eq!(unrotated.dimensions(), img.dimensions());
    }
}

View File

@@ -0,0 +1,483 @@
//! Text region detection and segmentation
use super::{RegionType, Result, TextRegion};
use image::GrayImage;
use std::collections::{HashMap, HashSet};
/// Locate candidate text regions in a binary or grayscale image.
///
/// Pipeline: connected-component labelling → per-component bounding boxes →
/// area filter (`min_region_size` pixels) → merge of boxes within 10px of
/// each other → grouping into lines → heuristic classification into
/// text / math / table / figure regions.
///
/// # Arguments
/// * `image` - Input grayscale or binary image
/// * `min_region_size` - Minimum region area in pixels
///
/// # Returns
/// Detected regions with bounding boxes and heuristic types.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::segmentation::find_text_regions;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let regions = find_text_regions(&image, 100).unwrap();
/// println!("Found {} regions", regions.len());
/// ```
pub fn find_text_regions(image: &GrayImage, min_region_size: u32) -> Result<Vec<TextRegion>> {
    let labels = connected_components(image);
    let boxes = extract_bounding_boxes(&labels);
    let sized = filter_by_size(boxes, min_region_size);
    let merged = merge_overlapping_regions(sized, 10);
    let lines = find_text_lines(image, &merged);
    Ok(classify_regions(image, lines))
}
/// Label 4-connected dark regions (pixel < 128) via iterative flood fill.
///
/// Returns a `height x width` grid where every foreground component carries
/// a distinct positive label and background pixels stay 0.
fn connected_components(image: &GrayImage) -> Vec<Vec<u32>> {
    let (width, height) = image.dimensions();
    let mut labels = vec![vec![0u32; width as usize]; height as usize];
    let mut next_label = 1u32;
    for (x, y, pixel) in image.enumerate_pixels() {
        // Seed a new component at every unlabeled foreground pixel.
        let unlabeled = labels[y as usize][x as usize] == 0;
        if unlabeled && pixel[0] < 128 {
            flood_fill(image, &mut labels, x, y, next_label);
            next_label += 1;
        }
    }
    labels
}
/// Stack-based flood fill: mark every 4-connected dark pixel reachable from
/// `(start_x, start_y)` with `label`. Iterative to avoid recursion-depth
/// limits on large components.
fn flood_fill(image: &GrayImage, labels: &mut [Vec<u32>], start_x: u32, start_y: u32, label: u32) {
    let (width, height) = image.dimensions();
    let mut pending = vec![(start_x, start_y)];
    while let Some((x, y)) = pending.pop() {
        // Skip out-of-bounds coordinates.
        if x >= width || y >= height {
            continue;
        }
        // Skip already-labelled and background pixels.
        if labels[y as usize][x as usize] != 0 || image.get_pixel(x, y)[0] >= 128 {
            continue;
        }
        labels[y as usize][x as usize] = label;
        // Enqueue the 4-neighbourhood, guarding against unsigned wraparound.
        if x > 0 {
            pending.push((x - 1, y));
        }
        if x + 1 < width {
            pending.push((x + 1, y));
        }
        if y > 0 {
            pending.push((x, y - 1));
        }
        if y + 1 < height {
            pending.push((x, y + 1));
        }
    }
}
/// Compute, for every component label in the grid, its bounding box in
/// `(x, y, width, height)` form. Label 0 (background) is ignored.
fn extract_bounding_boxes(labels: &[Vec<u32>]) -> HashMap<u32, (u32, u32, u32, u32)> {
    // Track (min_x, min_y, max_x, max_y) per label while scanning.
    let mut extents: HashMap<u32, (u32, u32, u32, u32)> = HashMap::new();
    for (y, row) in labels.iter().enumerate() {
        for (x, &label) in row.iter().enumerate() {
            if label == 0 {
                continue;
            }
            let (x, y) = (x as u32, y as u32);
            let e = extents.entry(label).or_insert((x, y, x, y));
            e.0 = e.0.min(x);
            e.1 = e.1.min(y);
            e.2 = e.2.max(x);
            e.3 = e.3.max(y);
        }
    }
    // Re-express the inclusive extents as origin + size.
    extents
        .into_iter()
        .map(|(label, (x0, y0, x1, y1))| (label, (x0, y0, x1 - x0 + 1, y1 - y0 + 1)))
        .collect()
}
/// Keep only boxes whose area (`width * height`) reaches `min_size` pixels.
/// The label keys are discarded; only the boxes survive.
fn filter_by_size(
    bboxes: HashMap<u32, (u32, u32, u32, u32)>,
    min_size: u32,
) -> Vec<(u32, u32, u32, u32)> {
    let mut kept = Vec::with_capacity(bboxes.len());
    for bbox in bboxes.into_values() {
        let (_, _, w, h) = bbox;
        if w * h >= min_size {
            kept.push(bbox);
        }
    }
    kept
}
/// Merge overlapping or nearby regions
///
/// # Arguments
/// * `regions` - Vector of bounding boxes (x, y, width, height)
/// * `merge_distance` - Maximum distance to merge regions
pub fn merge_overlapping_regions(
regions: Vec<(u32, u32, u32, u32)>,
merge_distance: u32,
) -> Vec<(u32, u32, u32, u32)> {
if regions.is_empty() {
return regions;
}
let mut merged = Vec::new();
let mut used = HashSet::new();
for i in 0..regions.len() {
if used.contains(&i) {
continue;
}
let mut current = regions[i];
let mut changed = true;
while changed {
changed = false;
for j in (i + 1)..regions.len() {
if used.contains(&j) {
continue;
}
if boxes_overlap_or_close(&current, &regions[j], merge_distance) {
current = merge_boxes(&current, &regions[j]);
used.insert(j);
changed = true;
}
}
}
merged.push(current);
used.insert(i);
}
merged
}
/// True when the two `(x, y, w, h)` boxes intersect once each is expanded by
/// `distance` pixels on every side — i.e. they overlap or nearly touch.
fn boxes_overlap_or_close(
    box1: &(u32, u32, u32, u32),
    box2: &(u32, u32, u32, u32),
    distance: u32,
) -> bool {
    let &(ax, ay, aw, ah) = box1;
    let &(bx, by, bw, bh) = box2;
    // Two 1-D spans are "near" when each starts before the other's padded end.
    let near = |start_a: u32, len_a: u32, start_b: u32, len_b: u32| {
        start_a <= start_b + len_b + distance && start_b <= start_a + len_a + distance
    };
    // Boxes are close iff their spans are close on both axes.
    near(ax, aw, bx, bw) && near(ay, ah, by, bh)
}
/// Smallest box covering both inputs, in `(x, y, w, h)` form.
fn merge_boxes(box1: &(u32, u32, u32, u32), box2: &(u32, u32, u32, u32)) -> (u32, u32, u32, u32) {
    let &(ax, ay, aw, ah) = box1;
    let &(bx, by, bw, bh) = box2;
    let left = ax.min(bx);
    let top = ay.min(by);
    let right = (ax + aw).max(bx + bw);
    let bottom = (ay + ah).max(by + bh);
    (left, top, right - left, bottom - top)
}
/// Group boxes into horizontal lines by vertical proximity.
///
/// Boxes are sorted by their top edge; a box joins the current line when its
/// top is within half the taller box's height of the previous box's top,
/// otherwise it starts a new line. The image parameter is currently unused.
pub fn find_text_lines(
    _image: &GrayImage,
    regions: &[(u32, u32, u32, u32)],
) -> Vec<Vec<(u32, u32, u32, u32)>> {
    if regions.is_empty() {
        return Vec::new();
    }
    let mut ordered = regions.to_vec();
    ordered.sort_by_key(|r| r.1);

    let mut lines: Vec<Vec<(u32, u32, u32, u32)>> = Vec::new();
    let mut current = vec![ordered[0]];
    for &bbox in &ordered[1..] {
        let (_, top, _, height) = bbox;
        let &(_, prev_top, _, prev_height) = current.last().unwrap();
        // Vertical gap relative to the taller of the two boxes decides
        // whether this box continues the line.
        let taller = prev_height.max(height);
        let gap = if top > prev_top { top - prev_top } else { prev_top - top };
        if gap < taller / 2 {
            current.push(bbox);
        } else {
            lines.push(current);
            current = vec![bbox];
        }
    }
    lines.push(current);
    lines
}
/// Assign a `RegionType` to every box using cheap geometric heuristics.
///
/// Rules, in priority order: very wide boxes (aspect > 10) → `Table`; tall
/// narrow boxes (aspect < 0.5, height > 50) → `Figure`; small dense boxes
/// (ink density > 0.3, height < 30) → `Math`; everything else → `Text`.
/// Confidence is a fixed 0.8 since the classifier is heuristic.
fn classify_regions(
    image: &GrayImage,
    text_lines: Vec<Vec<(u32, u32, u32, u32)>>,
) -> Vec<TextRegion> {
    let mut out = Vec::new();
    for bbox in text_lines.into_iter().flatten() {
        let (x, y, width, height) = bbox;
        let aspect = width as f32 / height as f32;
        let ink = calculate_density(image, bbox);

        let region_type = if aspect > 10.0 {
            // Very wide region might be a table or figure caption.
            RegionType::Table
        } else if aspect < 0.5 && height > 50 {
            // Tall, narrow region: likely graphical content.
            RegionType::Figure
        } else if ink > 0.3 && height < 30 {
            // Small but dense: formula-like.
            RegionType::Math
        } else {
            RegionType::Text
        };

        out.push(TextRegion {
            region_type,
            bbox: (x, y, width, height),
            confidence: 0.8, // Heuristic classifier: fixed moderate confidence.
            text_height: height as f32,
            baseline_angle: 0.0,
        });
    }
    out
}
/// Fraction of pixels inside `bbox` that are dark (< 128). Returns 0.0 for a
/// degenerate (zero-area) box. The box is assumed to lie within the image;
/// callers pass component bounding boxes derived from the same image.
fn calculate_density(image: &GrayImage, bbox: (u32, u32, u32, u32)) -> f32 {
    let (x0, y0, w, h) = bbox;
    let area = (w * h) as f32;
    if area == 0.0 {
        return 0.0;
    }
    let mut dark = 0u32;
    for row in y0..y0 + h {
        for col in x0..x0 + w {
            if image.get_pixel(col, row)[0] < 128 {
                dark += 1;
            }
        }
    }
    dark as f32 / area
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::Luma;
    // White 200x200 canvas with three separated black rectangles that act as
    // synthetic text regions.
    fn create_test_image_with_rectangles() -> GrayImage {
        let mut img = GrayImage::new(200, 200);
        // Fill with white
        for pixel in img.pixels_mut() {
            *pixel = Luma([255]);
        }
        // Draw some black rectangles (simulating text regions)
        for y in 20..40 {
            for x in 20..100 {
                img.put_pixel(x, y, Luma([0]));
            }
        }
        for y in 60..80 {
            for x in 20..120 {
                img.put_pixel(x, y, Luma([0]));
            }
        }
        for y in 100..120 {
            for x in 20..80 {
                img.put_pixel(x, y, Luma([0]));
            }
        }
        img
    }
    // End-to-end pipeline: the three drawn rectangles should all be found.
    #[test]
    fn test_find_text_regions() {
        let img = create_test_image_with_rectangles();
        let regions = find_text_regions(&img, 100);
        assert!(regions.is_ok());
        let r = regions.unwrap();
        // Should find at least 3 regions
        assert!(r.len() >= 3);
        for region in r {
            println!("Region: {:?} at {:?}", region.region_type, region.bbox);
        }
    }
    // Labelling must assign at least one positive component id.
    #[test]
    fn test_connected_components() {
        let img = create_test_image_with_rectangles();
        let components = connected_components(&img);
        // Check that we have non-zero labels
        let max_label = components
            .iter()
            .flat_map(|row| row.iter())
            .max()
            .unwrap_or(&0);
        assert!(*max_label > 0);
    }
    // Two overlapping boxes collapse into one; the distant box survives.
    #[test]
    fn test_merge_overlapping_regions() {
        let regions = vec![(10, 10, 50, 20), (40, 10, 50, 20), (100, 100, 30, 30)];
        let merged = merge_overlapping_regions(regions, 10);
        // First two should merge, third stays separate
        assert_eq!(merged.len(), 2);
    }
    // The merged box is the union bounding box of both inputs.
    #[test]
    fn test_merge_boxes() {
        let box1 = (10, 10, 50, 20);
        let box2 = (40, 15, 30, 25);
        let merged = merge_boxes(&box1, &box2);
        assert_eq!(merged.0, 10); // min x
        assert_eq!(merged.1, 10); // min y
        assert!(merged.2 >= 50); // width
        assert!(merged.3 >= 25); // height
    }
    // Overlapping boxes are "close" at any merge distance.
    #[test]
    fn test_boxes_overlap() {
        let box1 = (10, 10, 50, 20);
        let box2 = (40, 10, 50, 20);
        assert!(boxes_overlap_or_close(&box1, &box2, 0));
        assert!(boxes_overlap_or_close(&box1, &box2, 10));
    }
    // Far-apart boxes must not be considered close at distance 0.
    #[test]
    fn test_boxes_dont_overlap() {
        let box1 = (10, 10, 20, 20);
        let box2 = (100, 100, 20, 20);
        assert!(!boxes_overlap_or_close(&box1, &box2, 0));
    }
    // Four boxes on two baselines should group into two lines of two.
    #[test]
    fn test_find_text_lines() {
        let regions = vec![
            (10, 10, 50, 20),
            (70, 12, 50, 20),
            (10, 50, 50, 20),
            (70, 52, 50, 20),
        ];
        let img = GrayImage::new(200, 100);
        let lines = find_text_lines(&img, &regions);
        // Should find 2 lines
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0].len(), 2);
        assert_eq!(lines[1].len(), 2);
    }
    // A checkerboard pattern is ~50% dark pixels.
    #[test]
    fn test_calculate_density() {
        let mut img = GrayImage::new(100, 100);
        // Fill region with 50% black pixels
        for y in 10..30 {
            for x in 10..30 {
                let val = if (x + y) % 2 == 0 { 0 } else { 255 };
                img.put_pixel(x, y, Luma([val]));
            }
        }
        let density = calculate_density(&img, (10, 10, 20, 20));
        assert!((density - 0.5).abs() < 0.1);
    }
    // Only boxes with area >= 500 survive the size filter.
    #[test]
    fn test_filter_by_size() {
        let mut bboxes = HashMap::new();
        bboxes.insert(1, (10, 10, 50, 50)); // 2500 pixels
        bboxes.insert(2, (100, 100, 10, 10)); // 100 pixels
        bboxes.insert(3, (200, 200, 30, 30)); // 900 pixels
        let filtered = filter_by_size(bboxes, 500);
        // Should keep regions 1 and 3
        assert_eq!(filtered.len(), 2);
    }
}

View File

@@ -0,0 +1,400 @@
//! Image transformation functions for preprocessing
use super::{PreprocessError, Result};
use image::{DynamicImage, GrayImage, Luma};
use imageproc::filter::gaussian_blur_f32;
use std::f32;
/// Collapse any supported color image down to an 8-bit luma (grayscale)
/// buffer; grayscale inputs pass through as a copy.
///
/// # Arguments
/// * `img` - Input color or grayscale image
///
/// # Returns
/// Grayscale image
pub fn to_grayscale(img: &DynamicImage) -> GrayImage {
    img.to_luma8()
}
/// Blur `img` with a Gaussian kernel of standard deviation `sigma` to
/// suppress high-frequency noise ahead of binarization.
///
/// # Errors
/// `PreprocessError::InvalidParameters` unless `sigma` is strictly positive.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::transforms::gaussian_blur;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let blurred = gaussian_blur(&image, 1.5).unwrap();
/// ```
pub fn gaussian_blur(img: &GrayImage, sigma: f32) -> Result<GrayImage> {
    // Note: the comparison direction also rejects negative values while
    // leaving NaN handling identical to the original implementation.
    if sigma <= 0.0 {
        return Err(PreprocessError::InvalidParameters(
            "Sigma must be positive".to_string(),
        ));
    }
    Ok(gaussian_blur_f32(img, sigma))
}
/// Sharpen via unsharp masking: `out = orig + amount * (orig - blurred)`,
/// clamped to the valid 0-255 range.
///
/// # Arguments
/// * `img` - Input grayscale image
/// * `sigma` - Gaussian sigma used to build the blurred mask (must be > 0)
/// * `amount` - Strength of the effect (must be >= 0; 0.5-2.0 is typical)
///
/// # Errors
/// `PreprocessError::InvalidParameters` on out-of-range `sigma` or `amount`.
pub fn sharpen(img: &GrayImage, sigma: f32, amount: f32) -> Result<GrayImage> {
    if sigma <= 0.0 || amount < 0.0 {
        return Err(PreprocessError::InvalidParameters(
            "Invalid sharpening parameters".to_string(),
        ));
    }
    let mask = gaussian_blur_f32(img, sigma);
    let mut out = GrayImage::new(img.width(), img.height());
    for (x, y, pixel) in out.enumerate_pixels_mut() {
        let orig = img.get_pixel(x, y)[0] as f32;
        let soft = mask.get_pixel(x, y)[0] as f32;
        // Unsharp mask: boost the difference between the pixel and its
        // low-pass (blurred) version.
        let value = (orig + amount * (orig - soft)).clamp(0.0, 255.0) as u8;
        *pixel = Luma([value]);
    }
    Ok(out)
}
/// Pick a global binarization threshold with Otsu's method.
///
/// Builds a 256-bin histogram and selects the gray level that maximizes the
/// between-class variance of the background (levels <= t) versus foreground
/// (levels > t) split.
///
/// # Returns
/// Optimal threshold value (0-255)
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::transforms::otsu_threshold;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let threshold = otsu_threshold(&image).unwrap();
/// println!("Optimal threshold: {}", threshold);
/// ```
pub fn otsu_threshold(image: &GrayImage) -> Result<u8> {
    // 256-bin intensity histogram.
    let mut hist = [0u32; 256];
    for px in image.pixels() {
        hist[px[0] as usize] += 1;
    }

    let total = (image.width() * image.height()) as f64;
    // Intensity-weighted sum over the whole histogram.
    let weighted_total: f64 = hist
        .iter()
        .enumerate()
        .map(|(level, &n)| level as f64 * n as f64)
        .sum();

    let mut bg_weight = 0.0f64;
    let mut bg_weighted_sum = 0.0f64;
    let mut best = (0.0f64, 0u8); // (between-class variance, threshold)

    for (level, &n) in hist.iter().enumerate() {
        bg_weight += n as f64;
        if bg_weight == 0.0 {
            continue; // No pixels at or below this level yet.
        }
        let fg_weight = total - bg_weight;
        if fg_weight == 0.0 {
            break; // Every pixel is background from here on.
        }
        bg_weighted_sum += level as f64 * n as f64;
        let bg_mean = bg_weighted_sum / bg_weight;
        let fg_mean = (weighted_total - bg_weighted_sum) / fg_weight;
        // Between-class variance for this candidate split.
        let between = bg_weight * fg_weight * (bg_mean - fg_mean).powi(2);
        if between > best.0 {
            best = (between, level as u8);
        }
    }
    Ok(best.1)
}
/// Binarize: pixels at or above `threshold_val` become 255 (white), the rest
/// become 0 (black).
pub fn threshold(image: &GrayImage, threshold_val: u8) -> GrayImage {
    let mut out = GrayImage::new(image.width(), image.height());
    for (x, y, px) in image.enumerate_pixels() {
        let bit = if px[0] >= threshold_val { 255 } else { 0 };
        out.put_pixel(x, y, Luma([bit]));
    }
    out
}
/// Binarize with a locally adaptive threshold.
///
/// Each pixel is compared against the mean of its `window_size` square
/// neighbourhood (computed in O(1) per pixel via an integral image) minus a
/// small bias; this copes with uneven illumination far better than a single
/// global threshold.
///
/// # Errors
/// `PreprocessError::InvalidParameters` when `window_size` is even.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::transforms::adaptive_threshold;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let binary = adaptive_threshold(&image, 15).unwrap();
/// ```
pub fn adaptive_threshold(image: &GrayImage, window_size: u32) -> Result<GrayImage> {
    if window_size % 2 == 0 {
        return Err(PreprocessError::InvalidParameters(
            "Window size must be odd".to_string(),
        ));
    }
    let (width, height) = image.dimensions();
    let half = (window_size / 2) as i32;
    let integral = compute_integral_image(image);
    let mut out = GrayImage::new(width, height);

    for y in 0..height as i32 {
        for x in 0..width as i32 {
            // Clamp the window to the image border.
            let left = (x - half).max(0);
            let top = (y - half).max(0);
            let right = (x + half + 1).min(width as i32);
            let bottom = (y + half + 1).min(height as i32);

            // Local mean from the integral image in constant time.
            let area = ((right - left) * (bottom - top)) as f64;
            let window_sum = get_integral_sum(&integral, left, top, right, bottom);
            let local_mean = (window_sum as f64 / area) as u8;

            // Bias the mean down slightly so faint noise stays white.
            let bias = 5;
            let src = image.get_pixel(x as u32, y as u32)[0];
            let bit = if src >= local_mean.saturating_sub(bias) {
                255
            } else {
                0
            };
            out.put_pixel(x as u32, y as u32, Luma([bit]));
        }
    }
    Ok(out)
}
/// Summed-area table padded with a zero border row and column so rectangle
/// sums need no boundary special-casing. `table[y][x]` holds the sum of all
/// pixels strictly above and to the left of (x, y).
fn compute_integral_image(image: &GrayImage) -> Vec<Vec<u64>> {
    let (width, height) = image.dimensions();
    let (w, h) = (width as usize, height as usize);
    let mut table = vec![vec![0u64; w + 1]; h + 1];
    for y in 1..=h {
        for x in 1..=w {
            let px = image.get_pixel((x - 1) as u32, (y - 1) as u32)[0] as u64;
            // Standard inclusion-exclusion recurrence.
            table[y][x] = px + table[y - 1][x] + table[y][x - 1] - table[y - 1][x - 1];
        }
    }
    table
}
/// Sum of the half-open rectangle `[x1, x2) x [y1, y2)` via the summed-area
/// identity. Coordinates index the zero-padded integral table.
fn get_integral_sum(integral: &[Vec<u64>], x1: i32, y1: i32, x2: i32, y2: i32) -> u64 {
    let (x1, y1, x2, y2) = (x1 as usize, y1 as usize, x2 as usize, y2 as usize);
    integral[y2][x2] + integral[y1][x1] - integral[y1][x2] - integral[y2][x1]
}
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `assert_relative_eq` appears unused in this module —
    // confirm before removing the import.
    use approx::assert_relative_eq;
    // Diagonal gradient image: value grows with x + y.
    fn create_gradient_image(width: u32, height: u32) -> GrayImage {
        let mut img = GrayImage::new(width, height);
        for y in 0..height {
            for x in 0..width {
                let val = ((x + y) * 255 / (width + height)) as u8;
                img.put_pixel(x, y, Luma([val]));
            }
        }
        img
    }
    // Conversion preserves dimensions.
    #[test]
    fn test_to_grayscale() {
        let img = DynamicImage::new_rgb8(100, 100);
        let gray = to_grayscale(&img);
        assert_eq!(gray.dimensions(), (100, 100));
    }
    // A valid sigma blurs successfully and preserves dimensions.
    #[test]
    fn test_gaussian_blur() {
        let img = create_gradient_image(50, 50);
        let blurred = gaussian_blur(&img, 1.0);
        assert!(blurred.is_ok());
        let result = blurred.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }
    // Non-positive sigma is rejected.
    #[test]
    fn test_gaussian_blur_invalid_sigma() {
        let img = create_gradient_image(50, 50);
        let result = gaussian_blur(&img, -1.0);
        assert!(result.is_err());
    }
    // Sharpening with valid parameters preserves dimensions.
    #[test]
    fn test_sharpen() {
        let img = create_gradient_image(50, 50);
        let sharpened = sharpen(&img, 1.0, 1.5);
        assert!(sharpened.is_ok());
        let result = sharpened.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }
    // On a bimodal image Otsu should land between the two modes.
    #[test]
    fn test_otsu_threshold() {
        // Create bimodal image (good for Otsu)
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = if x < 50 { 50 } else { 200 };
                img.put_pixel(x, y, Luma([val]));
            }
        }
        let threshold = otsu_threshold(&img);
        assert!(threshold.is_ok());
        let t = threshold.unwrap();
        // Should be somewhere between the two values (not necessarily strictly between)
        // Otsu finds optimal threshold which could be at boundary
        assert!(
            t >= 50 && t <= 200,
            "threshold {} should be between 50 and 200",
            t
        );
    }
    // Global thresholding outputs strictly binary pixel values.
    #[test]
    fn test_threshold() {
        let img = create_gradient_image(100, 100);
        let binary = threshold(&img, 128);
        assert_eq!(binary.dimensions(), img.dimensions());
        // Check that output is binary
        for pixel in binary.pixels() {
            let val = pixel[0];
            assert!(val == 0 || val == 255);
        }
    }
    // Adaptive thresholding also outputs strictly binary pixel values.
    #[test]
    fn test_adaptive_threshold() {
        let img = create_gradient_image(100, 100);
        let binary = adaptive_threshold(&img, 15);
        assert!(binary.is_ok());
        let result = binary.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
        // Check binary output
        for pixel in result.pixels() {
            let val = pixel[0];
            assert!(val == 0 || val == 255);
        }
    }
    // Even window sizes are rejected.
    #[test]
    fn test_adaptive_threshold_invalid_window() {
        let img = create_gradient_image(50, 50);
        let result = adaptive_threshold(&img, 16); // Even number
        assert!(result.is_err());
    }
    // Full-image sum over an all-ones 3x3 image must be 9.
    #[test]
    fn test_integral_image() {
        let mut img = GrayImage::new(3, 3);
        for y in 0..3 {
            for x in 0..3 {
                img.put_pixel(x, y, Luma([1]));
            }
        }
        let integral = compute_integral_image(&img);
        // Check 3x3 sum
        let sum = get_integral_sum(&integral, 0, 0, 3, 3);
        assert_eq!(sum, 9); // 3x3 image with all 1s
    }
    // Threshold 0 saturates white; threshold 255 only keeps exact 255s.
    #[test]
    fn test_threshold_extremes() {
        let img = create_gradient_image(100, 100);
        // Threshold at 0 should make everything white
        let binary = threshold(&img, 0);
        assert!(binary.pixels().all(|p| p[0] == 255));
        // Threshold at 255 should make everything black
        let binary = threshold(&img, 255);
        assert!(binary.pixels().all(|p| p[0] == 0));
    }
}

View File

@@ -0,0 +1,189 @@
//! JavaScript API for Scipix OCR
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use wasm_bindgen::prelude::*;
use web_sys::{HtmlCanvasElement, ImageData};
use crate::wasm::canvas::CanvasProcessor;
use crate::wasm::memory::WasmBuffer;
use crate::wasm::types::{OcrResult, RecognitionFormat};
// Process-wide, lazily-initialized processor so repeated `ScipixWasm`
// constructions share a single `CanvasProcessor` instance.
static PROCESSOR: OnceCell<Arc<CanvasProcessor>> = OnceCell::new();
/// Main WASM API for Scipix OCR
#[wasm_bindgen]
pub struct ScipixWasm {
    // Shared canvas/image processing backend (see `PROCESSOR`).
    processor: Arc<CanvasProcessor>,
    // Requested output: plain text, LaTeX, or both.
    format: RecognitionFormat,
    // Results below this confidence are blanked out (clamped to 0.0 - 1.0).
    confidence_threshold: f32,
}
#[wasm_bindgen]
impl ScipixWasm {
    /// Create a new ScipixWasm instance
    ///
    /// All instances share one process-wide `CanvasProcessor` (lazily
    /// initialized in `PROCESSOR`). Defaults: both text and LaTeX output,
    /// confidence threshold 0.5.
    #[wasm_bindgen(constructor)]
    pub async fn new() -> Result<ScipixWasm, JsValue> {
        let processor = PROCESSOR
            .get_or_init(|| Arc::new(CanvasProcessor::new()))
            .clone();
        Ok(ScipixWasm {
            processor,
            format: RecognitionFormat::Both,
            confidence_threshold: 0.5,
        })
    }
    /// Recognize text from raw image data
    ///
    /// `image_data` must be an encoded image (e.g. PNG/JPEG bytes). Results
    /// below the confidence threshold are blanked before serialization.
    #[wasm_bindgen]
    pub async fn recognize(&self, image_data: &[u8]) -> Result<JsValue, JsValue> {
        let buffer = WasmBuffer::from_slice(image_data);
        let result = self
            .processor
            .process_image_bytes(buffer.as_slice(), self.format)
            .await
            .map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
        // Filter by confidence threshold
        let filtered = self.filter_by_confidence(result);
        serde_wasm_bindgen::to_value(&filtered)
            .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
    }
    /// Recognize text from HTML Canvas element
    ///
    /// Reads the canvas pixels via its 2d context; fails when the context is
    /// unavailable or the pixel data cannot be read.
    #[wasm_bindgen(js_name = recognizeFromCanvas)]
    pub async fn recognize_from_canvas(
        &self,
        canvas: &HtmlCanvasElement,
    ) -> Result<JsValue, JsValue> {
        let image_data = self
            .processor
            .extract_canvas_image(canvas)
            .map_err(|e| JsValue::from_str(&format!("Canvas extraction failed: {}", e)))?;
        let result = self
            .processor
            .process_image_data(&image_data, self.format)
            .await
            .map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
        let filtered = self.filter_by_confidence(result);
        serde_wasm_bindgen::to_value(&filtered)
            .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
    }
    /// Recognize text from base64-encoded image
    ///
    /// Accepts both bare base64 and `data:` URLs (anything up to the first
    /// comma is treated as the data-URL prefix and stripped).
    #[wasm_bindgen(js_name = recognizeBase64)]
    pub async fn recognize_base64(&self, base64: &str) -> Result<JsValue, JsValue> {
        // Remove data URL prefix if present
        let base64_data = if base64.contains(',') {
            base64.split(',').nth(1).unwrap_or(base64)
        } else {
            base64
        };
        let image_bytes = base64::decode(base64_data)
            .map_err(|e| JsValue::from_str(&format!("Base64 decode failed: {}", e)))?;
        self.recognize(&image_bytes).await
    }
    /// Recognize text from ImageData object
    #[wasm_bindgen(js_name = recognizeImageData)]
    pub async fn recognize_image_data(&self, image_data: &ImageData) -> Result<JsValue, JsValue> {
        let result = self
            .processor
            .process_image_data(image_data, self.format)
            .await
            .map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
        let filtered = self.filter_by_confidence(result);
        serde_wasm_bindgen::to_value(&filtered)
            .map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
    }
    /// Set the output format (text, latex, or both)
    ///
    /// Matching is case-insensitive; unknown strings fall back to "both".
    #[wasm_bindgen(js_name = setFormat)]
    pub fn set_format(&mut self, format: &str) {
        self.format = match format.to_lowercase().as_str() {
            "text" => RecognitionFormat::Text,
            "latex" => RecognitionFormat::Latex,
            "both" => RecognitionFormat::Both,
            _ => RecognitionFormat::Both,
        };
    }
    /// Set the confidence threshold (0.0 - 1.0)
    ///
    /// Out-of-range values are clamped into [0.0, 1.0].
    #[wasm_bindgen(js_name = setConfidenceThreshold)]
    pub fn set_confidence_threshold(&mut self, threshold: f32) {
        self.confidence_threshold = threshold.clamp(0.0, 1.0);
    }
    /// Get the current confidence threshold
    #[wasm_bindgen(js_name = getConfidenceThreshold)]
    pub fn get_confidence_threshold(&self) -> f32 {
        self.confidence_threshold
    }
    /// Get the version of the library
    #[wasm_bindgen(js_name = getVersion)]
    pub fn get_version(&self) -> String {
        env!("CARGO_PKG_VERSION").to_string()
    }
    /// Get supported output formats
    #[wasm_bindgen(js_name = getSupportedFormats)]
    pub fn get_supported_formats(&self) -> Vec<JsValue> {
        vec![
            JsValue::from_str("text"),
            JsValue::from_str("latex"),
            JsValue::from_str("both"),
        ]
    }
    /// Batch process multiple images
    ///
    /// Each element is treated as a `Uint8Array` of encoded image bytes.
    /// Items that fail to process log a console warning and produce `null`,
    /// keeping output indices aligned with the input.
    #[wasm_bindgen(js_name = recognizeBatch)]
    pub async fn recognize_batch(&self, images: Vec<JsValue>) -> Result<JsValue, JsValue> {
        let mut results = Vec::new();
        for img in images {
            // FIX: `js_sys::Uint8Array::to_vec` returns `Vec<u8>`, not a
            // `Result`, so the previous `if let Ok(bytes) = ...` pattern did
            // not type-check; bind the bytes directly.
            let bytes = js_sys::Uint8Array::new(&img).to_vec();
            match self.recognize(&bytes).await {
                Ok(result) => results.push(result),
                Err(e) => {
                    web_sys::console::warn_1(&JsValue::from_str(&format!(
                        "Failed to process image: {:?}",
                        e
                    )));
                    results.push(JsValue::NULL);
                }
            }
        }
        Ok(js_sys::Array::from_iter(results).into())
    }
    // Private helper methods
    /// Blank out results whose confidence is below the configured threshold.
    fn filter_by_confidence(&self, mut result: OcrResult) -> OcrResult {
        if result.confidence < self.confidence_threshold {
            result.text = String::new();
            result.latex = None;
        }
        result
    }
}
/// Create a new ScipixWasm instance (factory function)
///
/// Convenience wrapper for JS callers that prefer `createScipix()` over
/// `new ScipixWasm()`; forwards directly to the constructor.
#[wasm_bindgen(js_name = createScipix)]
pub async fn create_scipix() -> Result<ScipixWasm, JsValue> {
    ScipixWasm::new().await
}

View File

@@ -0,0 +1,217 @@
//! Canvas and ImageData handling for WASM
use anyhow::{anyhow, Result};
use image::{DynamicImage, ImageBuffer, Rgba};
use wasm_bindgen::prelude::*;
use web_sys::{CanvasRenderingContext2d, HtmlCanvasElement, ImageData};
use crate::wasm::types::{OcrResult, RecognitionFormat};
/// Processor for canvas and image data
///
/// Currently stateless; the empty body leaves room to cache a loaded model
/// or other shared resources later.
pub struct CanvasProcessor {
    // Could add model loading here in the future
}
impl CanvasProcessor {
    /// Create a new canvas processor
    ///
    /// Allocation-free: the processor currently holds no state.
    pub fn new() -> Self {
        Self {}
    }
/// Extract image data from HTML canvas element
pub fn extract_canvas_image(&self, canvas: &HtmlCanvasElement) -> Result<ImageData> {
let context = canvas
.get_context("2d")
.map_err(|_| anyhow!("Failed to get 2d context"))?
.ok_or_else(|| anyhow!("Context is null"))?
.dyn_into::<CanvasRenderingContext2d>()
.map_err(|_| anyhow!("Failed to cast to 2d context"))?;
let width = canvas.width();
let height = canvas.height();
context
.get_image_data(0.0, 0.0, width as f64, height as f64)
.map_err(|_| anyhow!("Failed to get image data"))
}
/// Convert ImageData to DynamicImage
pub fn image_data_to_dynamic(&self, image_data: &ImageData) -> Result<DynamicImage> {
let width = image_data.width();
let height = image_data.height();
let data = image_data.data();
let img_buffer = ImageBuffer::<Rgba<u8>, Vec<u8>>::from_raw(width, height, data.to_vec())
.ok_or_else(|| anyhow!("Failed to create image buffer"))?;
Ok(DynamicImage::ImageRgba8(img_buffer))
}
/// Process raw image bytes
pub async fn process_image_bytes(
&self,
image_bytes: &[u8],
format: RecognitionFormat,
) -> Result<OcrResult> {
// Decode image
let img = image::load_from_memory(image_bytes)
.map_err(|e| anyhow!("Failed to decode image: {}", e))?;
self.process_dynamic_image(&img, format).await
}
/// Process ImageData from canvas
pub async fn process_image_data(
&self,
image_data: &ImageData,
format: RecognitionFormat,
) -> Result<OcrResult> {
let img = self.image_data_to_dynamic(image_data)?;
self.process_dynamic_image(&img, format).await
}
/// Process a DynamicImage
async fn process_dynamic_image(
&self,
img: &DynamicImage,
format: RecognitionFormat,
) -> Result<OcrResult> {
// Convert to grayscale for processing
let gray = img.to_luma8();
// Apply preprocessing
let preprocessed = self.preprocess_image(&gray);
// Perform OCR (mock implementation for now)
// In a real implementation, this would run a model
let text = self.extract_text(&preprocessed)?;
let latex = if matches!(format, RecognitionFormat::Latex | RecognitionFormat::Both) {
Some(self.extract_latex(&preprocessed)?)
} else {
None
};
// Calculate confidence (simplified)
let confidence = self.calculate_confidence(&text, &latex);
Ok(OcrResult {
text,
latex,
confidence,
metadata: Some(serde_json::json!({
"width": img.width(),
"height": img.height(),
"format": format.to_string(),
})),
})
}
/// Preprocess image for OCR
fn preprocess_image(&self, img: &image::GrayImage) -> image::GrayImage {
// Apply simple thresholding
let mut output = img.clone();
for pixel in output.pixels_mut() {
let value = pixel.0[0];
pixel.0[0] = if value > 128 { 255 } else { 0 };
}
output
}
/// Extract plain text (mock implementation)
fn extract_text(&self, img: &image::GrayImage) -> Result<String> {
// This would normally run an OCR model
// For now, return a placeholder
Ok("Recognized text placeholder".to_string())
}
/// Extract LaTeX (mock implementation)
fn extract_latex(&self, img: &image::GrayImage) -> Result<String> {
// This would normally run a math OCR model
// For now, return a placeholder
Ok(r"\sum_{i=1}^{n} x_i".to_string())
}
/// Calculate confidence score
fn calculate_confidence(&self, text: &str, latex: &Option<String>) -> f32 {
// Simple heuristic: longer text = higher confidence
let text_score = (text.len() as f32 / 100.0).min(1.0);
let latex_score = latex
.as_ref()
.map(|l| (l.len() as f32 / 50.0).min(1.0))
.unwrap_or(0.0);
(text_score + latex_score) / 2.0
}
}
impl Default for CanvasProcessor {
fn default() -> Self {
Self::new()
}
}
/// Convert blob URL to image data
///
/// Loads `blob_url` into an off-screen `HtmlImageElement`, draws it onto a
/// temporary canvas at its natural size, and returns the resulting RGBA
/// `ImageData`.
///
/// Fix: the load/error handlers are now installed *before* `set_src`.
/// Setting `src` first can fire the `load` event (notably for cached blobs)
/// before any handler is attached, leaving the promise forever pending.
#[wasm_bindgen]
pub async fn blob_url_to_image_data(blob_url: &str) -> Result<ImageData, JsValue> {
    use web_sys::{window, HtmlImageElement};
    let window = window().ok_or_else(|| JsValue::from_str("No window"))?;
    let document = window
        .document()
        .ok_or_else(|| JsValue::from_str("No document"))?;
    // Create image element
    let img =
        HtmlImageElement::new().map_err(|_| JsValue::from_str("Failed to create image element"))?;
    // Wire up onload/onerror inside the promise executor, which runs
    // synchronously during `Promise::new`.
    let promise = js_sys::Promise::new(&mut |resolve, reject| {
        let img_clone = img.clone();
        let onload = Closure::wrap(Box::new(move || {
            // Ignore call errors: if `resolve` itself throws there is nothing
            // useful we can do from inside an event handler.
            let _ = resolve.call1(&JsValue::NULL, &img_clone);
        }) as Box<dyn FnMut()>);
        img.set_onload(Some(onload.as_ref().unchecked_ref()));
        onload.forget();
        let onerror = Closure::wrap(Box::new(move || {
            let _ = reject.call1(&JsValue::NULL, &JsValue::from_str("Image load failed"));
        }) as Box<dyn FnMut()>);
        img.set_onerror(Some(onerror.as_ref().unchecked_ref()));
        onerror.forget();
    });
    // Start loading only after the handlers are registered.
    img.set_src(blob_url);
    wasm_bindgen_futures::JsFuture::from(promise).await?;
    // Create canvas and draw image at its natural resolution.
    let canvas = document
        .create_element("canvas")
        .map_err(|_| JsValue::from_str("Failed to create canvas"))?
        .dyn_into::<HtmlCanvasElement>()
        .map_err(|_| JsValue::from_str("Failed to cast to canvas"))?;
    canvas.set_width(img.natural_width());
    canvas.set_height(img.natural_height());
    let context = canvas
        .get_context("2d")
        .map_err(|_| JsValue::from_str("Failed to get 2d context"))?
        .ok_or_else(|| JsValue::from_str("Context is null"))?
        .dyn_into::<CanvasRenderingContext2d>()
        .map_err(|_| JsValue::from_str("Failed to cast to 2d context"))?;
    context
        .draw_image_with_html_image_element(&img, 0.0, 0.0)
        .map_err(|_| JsValue::from_str("Failed to draw image"))?;
    context
        .get_image_data(0.0, 0.0, canvas.width() as f64, canvas.height() as f64)
        .map_err(|_| JsValue::from_str("Failed to get image data"))
}

View File

@@ -0,0 +1,218 @@
//! Memory management for WASM
use std::ops::Deref;
use wasm_bindgen::prelude::*;
/// Efficient buffer wrapper for WASM memory management
///
/// Thin wrapper over `Vec<u8>` whose `Drop` eagerly clears and shrinks the
/// allocation to help keep the WASM linear-memory footprint small.
pub struct WasmBuffer {
    data: Vec<u8>,
}

impl WasmBuffer {
    /// Create a new, empty buffer able to hold `capacity` bytes without reallocating.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
        }
    }

    /// Create buffer from slice (copies data).
    pub fn from_slice(slice: &[u8]) -> Self {
        Self {
            data: slice.to_vec(),
        }
    }

    /// Create buffer from Vec (takes ownership, no copy).
    pub fn from_vec(data: Vec<u8>) -> Self {
        Self { data }
    }

    /// Borrow the contents as an immutable byte slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Borrow the contents as a mutable byte slice.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Number of bytes currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// True when no bytes are stored.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Clear the buffer (length becomes 0; capacity is kept for reuse).
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Release excess capacity back to the allocator.
    pub fn shrink_to_fit(&mut self) {
        self.data.shrink_to_fit();
    }

    /// Consume the buffer, returning the underlying Vec.
    ///
    /// Fix: `WasmBuffer` implements `Drop`, so `self.data` cannot be moved
    /// out directly (error E0509). `std::mem::take` swaps in an empty Vec,
    /// letting `Drop` run harmlessly on the emptied buffer.
    pub fn into_vec(mut self) -> Vec<u8> {
        std::mem::take(&mut self.data)
    }
}

impl Deref for WasmBuffer {
    type Target = [u8];
    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl Drop for WasmBuffer {
    fn drop(&mut self) {
        // Explicitly clear and shrink so the allocation is handed back to
        // the WASM allocator promptly.
        self.data.clear();
        self.data.shrink_to_fit();
    }
}
/// Shared memory for large images (uses SharedArrayBuffer when available)
#[wasm_bindgen]
pub struct SharedImageBuffer {
    buffer: WasmBuffer,
    width: u32,
    height: u32,
}

#[wasm_bindgen]
impl SharedImageBuffer {
    /// Create a new shared buffer sized for a `width` x `height` RGBA image.
    ///
    /// Fix: the byte count is computed in `u64` — `width * height * 4` in
    /// `u32` overflows for large dimensions (e.g. 32768 x 32768 x 4), which
    /// panics in debug builds and silently wraps in release builds.
    #[wasm_bindgen(constructor)]
    pub fn new(width: u32, height: u32) -> Self {
        let bytes = width as u64 * height as u64 * 4; // RGBA
        let size = usize::try_from(bytes).expect("image dimensions exceed addressable memory");
        Self {
            buffer: WasmBuffer::with_capacity(size),
            width,
            height,
        }
    }
    /// Get width in pixels.
    #[wasm_bindgen(getter)]
    pub fn width(&self) -> u32 {
        self.width
    }
    /// Get height in pixels.
    #[wasm_bindgen(getter)]
    pub fn height(&self) -> u32 {
        self.height
    }
    /// Number of bytes currently stored (0 until `setBuffer` is called —
    /// the constructor only reserves capacity).
    #[wasm_bindgen(js_name = bufferSize)]
    pub fn buffer_size(&self) -> usize {
        self.buffer.len()
    }
    /// Copy the buffer out to a JS Uint8Array.
    #[wasm_bindgen(js_name = getBuffer)]
    pub fn get_buffer(&self) -> js_sys::Uint8Array {
        js_sys::Uint8Array::from(self.buffer.as_slice())
    }
    /// Replace the contents with a copy of `data`.
    #[wasm_bindgen(js_name = setBuffer)]
    pub fn set_buffer(&mut self, data: &js_sys::Uint8Array) {
        self.buffer = WasmBuffer::from_vec(data.to_vec());
    }
    /// Clear the contents (capacity is kept).
    pub fn clear(&mut self) {
        self.buffer.clear();
    }
}
/// Memory pool for reusing buffers
///
/// Holds up to `max_size` idle buffers so hot paths can avoid repeated
/// allocation. Pooled buffers are stored cleared; their capacity is kept.
pub struct MemoryPool {
    buffers: Vec<WasmBuffer>,
    max_size: usize,
}

impl MemoryPool {
    /// Create a new memory pool retaining at most `max_size` idle buffers.
    pub fn new(max_size: usize) -> Self {
        MemoryPool {
            buffers: Vec::with_capacity(max_size),
            max_size,
        }
    }

    /// Get a buffer from the pool or create a new one.
    ///
    /// `size` is only a capacity hint for freshly-allocated buffers; a
    /// reused buffer keeps whatever capacity it already had.
    pub fn acquire(&mut self, size: usize) -> WasmBuffer {
        match self.buffers.pop() {
            Some(mut reused) => {
                reused.clear();
                reused
            }
            None => WasmBuffer::with_capacity(size),
        }
    }

    /// Return a buffer to the pool; dropped instead when the pool is full.
    pub fn release(&mut self, mut buffer: WasmBuffer) {
        if self.buffers.len() >= self.max_size {
            // Pool is full: let the buffer drop and free its memory.
            return;
        }
        buffer.clear();
        self.buffers.push(buffer);
    }

    /// Drop every pooled buffer.
    pub fn clear(&mut self) {
        self.buffers.clear();
    }
}

impl Default for MemoryPool {
    /// A pool that retains at most 10 idle buffers.
    fn default() -> Self {
        MemoryPool::new(10)
    }
}
/// Get memory usage statistics
///
/// On wasm32 this reports what little the browser portably exposes — an
/// availability flag and a `performance.now()` timestamp; on every other
/// target it returns `null`.
#[wasm_bindgen(js_name = getMemoryStats)]
pub fn get_memory_stats() -> JsValue {
    #[cfg(target_arch = "wasm32")]
    {
        use wasm_bindgen::JsValue;
        // `performance.memory` is non-standard, so only availability and a
        // timestamp are surfaced here.
        match web_sys::window().and_then(|w| w.performance()) {
            Some(perf) => serde_wasm_bindgen::to_value(&serde_json::json!({
                "available": true,
                "timestamp": perf.now(),
            }))
            .unwrap_or(JsValue::NULL),
            None => JsValue::NULL,
        }
    }
    #[cfg(not(target_arch = "wasm32"))]
    JsValue::NULL
}
/// Force garbage collection (hint to runtime)
///
/// Intentionally a no-op: wasm-bindgen exposes no way to trigger the JS
/// engine's garbage collector. Freeing Rust-side buffers promptly (see
/// `WasmBuffer`) is the only lever actually under our control.
#[wasm_bindgen(js_name = forceGC)]
pub fn force_gc() {
    // This is just a hint; actual GC is controlled by the JS runtime
    // In wasm-bindgen, we can't directly trigger GC
    // But we can help by ensuring our memory is freed
}

View File

@@ -0,0 +1,49 @@
//! WebAssembly bindings for Scipix OCR
//!
//! This module provides WASM bindings with wasm-bindgen for browser-based OCR.
#![cfg(target_arch = "wasm32")]
pub mod api;
pub mod canvas;
pub mod memory;
pub mod types;
pub mod worker;
pub use api::ScipixWasm;
pub use types::*;
use wasm_bindgen::prelude::*;
/// Initialize the WASM module with panic hooks and allocator
///
/// Runs automatically when the module is instantiated (`#[wasm_bindgen(start)]`).
/// NOTE(review): `tracing_wasm::set_as_global_default` installs a global
/// subscriber — confirm this cannot run twice, as a second installation
/// would fail.
#[wasm_bindgen(start)]
pub fn init() {
    // Set panic hook for better error messages
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
    // Use wee_alloc for smaller binary size
    // NOTE(review): a `#[global_allocator]` static is conventionally declared
    // at module scope; declaring it inside this function body still registers
    // it at compile time, but only when the `wee_alloc` feature is enabled.
    #[cfg(feature = "wee_alloc")]
    {
        #[global_allocator]
        static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
    }
    // Initialize logging
    tracing_wasm::set_as_global_default();
}
/// Get the version of the WASM module
///
/// Returns the crate version baked in at compile time from Cargo metadata.
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Check if the WASM module is ready
///
/// Always `true` once the module has instantiated: `init` runs synchronously
/// at start-up, so being able to call this at all implies readiness.
#[wasm_bindgen]
pub fn is_ready() -> bool {
    true
}
// Bring tracing-wasm into scope for logging (note: a plain `use` is not a
// re-export; `pub use tracing_wasm;` would be needed to re-export it)
use tracing_wasm;

View File

@@ -0,0 +1,179 @@
//! Type definitions for WASM API
use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// OCR result returned to JavaScript
///
/// NOTE(review): `#[wasm_bindgen]` structs with public non-`Copy` fields
/// (`String`, `Option<String>`, `serde_json::Value`) generally need
/// `getter_with_clone` or `#[wasm_bindgen(skip)]` per field, and the impl
/// below also declares explicit getters with the same names — confirm this
/// compiles for the wasm32 target without duplicate-getter conflicts.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct OcrResult {
    /// Recognized plain text
    pub text: String,
    /// LaTeX representation (if applicable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,
    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,
    /// Additional metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
#[wasm_bindgen]
impl OcrResult {
/// Create a new OCR result
#[wasm_bindgen(constructor)]
pub fn new(text: String, confidence: f32) -> Self {
Self {
text,
latex: None,
confidence,
metadata: None,
}
}
/// Get the text
#[wasm_bindgen(getter)]
pub fn text(&self) -> String {
self.text.clone()
}
/// Get the LaTeX (if available)
#[wasm_bindgen(getter)]
pub fn latex(&self) -> Option<String> {
self.latex.clone()
}
/// Get the confidence score
#[wasm_bindgen(getter)]
pub fn confidence(&self) -> f32 {
self.confidence
}
/// Check if result has LaTeX
#[wasm_bindgen(js_name = hasLatex)]
pub fn has_latex(&self) -> bool {
self.latex.is_some()
}
/// Convert to JSON
#[wasm_bindgen(js_name = toJSON)]
pub fn to_json(&self) -> Result<JsValue, JsValue> {
serde_wasm_bindgen::to_value(self)
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
}
}
/// Recognition output format
///
/// Selects which representations the OCR pipeline should produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RecognitionFormat {
    /// Plain text only
    Text,
    /// LaTeX only
    Latex,
    /// Both text and LaTeX
    Both,
}
/// Human-readable name of the format: "text", "latex" or "both".
///
/// Implemented as `Display` rather than the previous inherent `to_string`
/// (clippy: `inherent_to_string`): the standard `ToString` blanket impl then
/// provides `.to_string()` for free, so existing callers keep working, and
/// the type can be used directly in `format!`-style macros.
impl std::fmt::Display for RecognitionFormat {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Text => "text",
            Self::Latex => "latex",
            Self::Both => "both",
        };
        f.write_str(name)
    }
}
impl Default for RecognitionFormat {
    // Default to the richest output so callers opt *out* of LaTeX rather
    // than having to remember to opt in.
    fn default() -> Self {
        Self::Both
    }
}
/// Processing options
///
/// Plain-data options bundle configurable from JavaScript.
/// NOTE(review): `#[wasm_bindgen]` with a public `String` field normally
/// requires `getter_with_clone` — confirm this compiles for wasm32.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct ProcessingOptions {
    /// Output format
    /// (defaults to "both"; "text" and "latex" also appear in `RecognitionFormat`)
    pub format: String,
    /// Confidence threshold
    /// (results below this may be filtered; default 0.5)
    pub confidence_threshold: f32,
    /// Enable preprocessing
    pub preprocess: bool,
    /// Enable postprocessing
    pub postprocess: bool,
}
#[wasm_bindgen]
impl ProcessingOptions {
/// Create default options
#[wasm_bindgen(constructor)]
pub fn new() -> Self {
Self::default()
}
/// Set format
#[wasm_bindgen(js_name = setFormat)]
pub fn set_format(&mut self, format: String) {
self.format = format;
}
/// Set confidence threshold
#[wasm_bindgen(js_name = setConfidenceThreshold)]
pub fn set_confidence_threshold(&mut self, threshold: f32) {
self.confidence_threshold = threshold;
}
}
impl Default for ProcessingOptions {
fn default() -> Self {
Self {
format: "both".to_string(),
confidence_threshold: 0.5,
preprocess: true,
postprocess: true,
}
}
}
/// Error types for WASM API
///
/// Converted into a `JsValue` string at the WASM boundary (see the `From`
/// impl below in this file), so each variant's `Display` text is what
/// JavaScript callers ultimately see.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WasmError {
    /// Image decoding error
    ImageDecode(String),
    /// Processing error
    Processing(String),
    /// Invalid input
    InvalidInput(String),
    /// Not initialized
    NotInitialized,
}
impl std::fmt::Display for WasmError {
    /// Render the error as its user-facing message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotInitialized => write!(f, "WASM module not initialized"),
            Self::ImageDecode(msg) => write!(f, "Image decode error: {}", msg),
            Self::Processing(msg) => write!(f, "Processing error: {}", msg),
            Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
        }
    }
}
// No underlying source error to expose, so the default trait methods suffice.
impl std::error::Error for WasmError {}
impl From<WasmError> for JsValue {
    /// Convert into a plain JS string so the error can cross the WASM boundary.
    fn from(error: WasmError) -> Self {
        let message = error.to_string();
        JsValue::from_str(&message)
    }
}

View File

@@ -0,0 +1,243 @@
//! Web Worker support for off-main-thread OCR processing
use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use wasm_bindgen::prelude::*;
use web_sys::{DedicatedWorkerGlobalScope, MessageEvent};
use crate::wasm::api::ScipixWasm;
use crate::wasm::types::RecognitionFormat;
// Worker-wide singleton set exactly once by `init_worker`; message handlers
// use it to verify initialization has happened.
static WORKER_INSTANCE: OnceCell<Arc<ScipixWasm>> = OnceCell::new();
/// Messages sent from main thread to worker
///
/// Internally tagged with a `"type"` field, so the JS side posts plain
/// objects like `{ type: "Process", id, image_data, format }`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerRequest {
    /// Initialize the worker
    Init,
    /// Process an image
    Process {
        // Correlation id echoed back in every response for this request.
        id: String,
        // Image bytes to recognize.
        image_data: Vec<u8>,
        // Output format name passed to `set_format`.
        format: String,
    },
    /// Process base64 image
    ProcessBase64 {
        id: String,
        base64: String,
        format: String,
    },
    /// Batch process images
    BatchProcess {
        id: String,
        images: Vec<Vec<u8>>,
        format: String,
    },
    /// Terminate worker
    Terminate,
}
/// Messages sent from worker to main thread
///
/// Internally tagged with a `"type"` field, mirroring `WorkerRequest`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerResponse {
    /// Worker is ready
    Ready,
    /// Processing started
    Started { id: String },
    /// Processing progress
    Progress {
        id: String,
        // Number of items completed so far out of `total`.
        processed: usize,
        total: usize,
    },
    /// Processing completed successfully
    Success {
        id: String,
        // Recognition output serialized as JSON.
        result: serde_json::Value,
    },
    /// Processing failed
    Error { id: String, error: String },
    /// Worker terminated
    Terminated,
}
/// Initialize the worker
///
/// Builds the shared `ScipixWasm` instance, stores it in the worker-wide
/// singleton, and notifies the main thread with `Ready`. Fails if called
/// more than once.
#[wasm_bindgen(js_name = initWorker)]
pub async fn init_worker() -> Result<(), JsValue> {
    let instance = Arc::new(ScipixWasm::new().await?);
    if WORKER_INSTANCE.set(instance).is_err() {
        return Err(JsValue::from_str("Worker already initialized"));
    }
    post_response(WorkerResponse::Ready)
}
/// Handle messages from the main thread
///
/// Decodes the `MessageEvent` payload into a `WorkerRequest` and dispatches
/// it to the matching async handler, propagating that handler's result.
#[wasm_bindgen(js_name = handleWorkerMessage)]
pub async fn handle_worker_message(event: MessageEvent) -> Result<(), JsValue> {
    let request: WorkerRequest = serde_wasm_bindgen::from_value(event.data())
        .map_err(|e| JsValue::from_str(&format!("Invalid message: {}", e)))?;
    match request {
        WorkerRequest::Init => init_worker().await,
        WorkerRequest::Process {
            id,
            image_data,
            format,
        } => process_image(id, image_data, format).await,
        WorkerRequest::ProcessBase64 { id, base64, format } => {
            process_base64(id, base64, format).await
        }
        WorkerRequest::BatchProcess { id, images, format } => {
            process_batch(id, images, format).await
        }
        WorkerRequest::Terminate => post_response(WorkerResponse::Terminated),
    }
}
/// Recognize a single image and post the outcome back to the main thread.
///
/// Fix: the previous version fetched the shared `WORKER_INSTANCE` into an
/// unused binding. The initialization check is kept, but no dead reference
/// is held. A fresh `ScipixWasm` is still created per request because
/// `set_format` requires `&mut self`, which the shared `Arc` cannot provide.
async fn process_image(id: String, image_data: Vec<u8>, format: String) -> Result<(), JsValue> {
    post_response(WorkerResponse::Started { id: id.clone() })?;
    if WORKER_INSTANCE.get().is_none() {
        return Err(JsValue::from_str("Worker not initialized"));
    }
    let mut worker_instance = ScipixWasm::new().await?;
    worker_instance.set_format(&format);
    match worker_instance.recognize(&image_data).await {
        Ok(result) => {
            let json_result: serde_json::Value = serde_wasm_bindgen::from_value(result)?;
            post_response(WorkerResponse::Success {
                id,
                result: json_result,
            })?;
        }
        Err(e) => {
            post_response(WorkerResponse::Error {
                id,
                error: format!("{:?}", e),
            })?;
        }
    }
    Ok(())
}
async fn process_base64(id: String, base64: String, format: String) -> Result<(), JsValue> {
post_response(WorkerResponse::Started { id: id.clone() })?;
let mut worker_instance = ScipixWasm::new().await?;
worker_instance.set_format(&format);
match worker_instance.recognize_base64(&base64).await {
Ok(result) => {
let json_result: serde_json::Value = serde_wasm_bindgen::from_value(result)?;
post_response(WorkerResponse::Success {
id,
result: json_result,
})?;
}
Err(e) => {
post_response(WorkerResponse::Error {
id,
error: format!("{:?}", e),
})?;
}
}
Ok(())
}
/// Recognize a batch of images, emitting a Progress message before each item.
///
/// Per-image failures are logged to the console and recorded as JSON `null`
/// so the results array always lines up one-to-one with the input order.
async fn process_batch(id: String, images: Vec<Vec<u8>>, format: String) -> Result<(), JsValue> {
    post_response(WorkerResponse::Started { id: id.clone() })?;
    let total = images.len();
    let mut worker_instance = ScipixWasm::new().await?;
    worker_instance.set_format(&format);
    let mut results = Vec::with_capacity(total);
    for (idx, image_data) in images.into_iter().enumerate() {
        // Report progress before starting item `idx`.
        post_response(WorkerResponse::Progress {
            id: id.clone(),
            processed: idx,
            total,
        })?;
        let entry = match worker_instance.recognize(&image_data).await {
            Ok(result) => serde_wasm_bindgen::from_value::<serde_json::Value>(result)?,
            Err(e) => {
                web_sys::console::warn_1(&JsValue::from_str(&format!(
                    "Failed to process image {}: {:?}",
                    idx, e
                )));
                serde_json::Value::Null
            }
        };
        results.push(entry);
    }
    post_response(WorkerResponse::Success {
        id,
        result: serde_json::json!({ "results": results }),
    })
}
/// Serialize a `WorkerResponse` and post it to the main thread.
fn post_response(response: WorkerResponse) -> Result<(), JsValue> {
    let scope: DedicatedWorkerGlobalScope = js_sys::global().dyn_into()?;
    let payload = serde_wasm_bindgen::to_value(&response)?;
    scope.post_message(&payload)
}
/// Setup worker message listener
///
/// Installs an `onmessage` handler on the worker's global scope that spawns
/// each incoming message onto the local async executor. The closure is
/// deliberately leaked (`forget`) so it stays alive for the worker's lifetime.
#[wasm_bindgen(js_name = setupWorker)]
pub fn setup_worker() -> Result<(), JsValue> {
    let scope: DedicatedWorkerGlobalScope = js_sys::global().dyn_into()?;
    let handler = Closure::wrap(Box::new(move |event: MessageEvent| {
        wasm_bindgen_futures::spawn_local(async move {
            if let Err(e) = handle_worker_message(event).await {
                web_sys::console::error_1(&e);
            }
        });
    }) as Box<dyn FnMut(MessageEvent)>);
    scope.set_onmessage(Some(handler.as_ref().unchecked_ref()));
    handler.forget(); // Keep the handler alive for the worker's lifetime.
    Ok(())
}