Squashed 'vendor/ruvector/' content from commit b64c2172
git-subtree-dir: vendor/ruvector git-subtree-split: b64c21726f2bb37286d9ee36a7869fef60cc6900
This commit is contained in:
308
examples/scipix/src/api/handlers.rs
Normal file
308
examples/scipix/src/api/handlers.rs
Normal file
@@ -0,0 +1,308 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
http::StatusCode,
|
||||
response::{sse::Event, IntoResponse, Sse},
|
||||
Json,
|
||||
};
|
||||
use futures::stream::{self, Stream};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{convert::Infallible, time::Duration};
|
||||
use tracing::{error, info, warn};
|
||||
use validator::Validate;
|
||||
|
||||
use super::{
|
||||
jobs::{JobStatus, PdfJob},
|
||||
requests::{LatexRequest, PdfRequest, StrokesRequest, TextRequest},
|
||||
responses::{ErrorResponse, PdfResponse, TextResponse},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// Health check handler
|
||||
pub async fn get_health() -> impl IntoResponse {
|
||||
#[derive(Serialize)]
|
||||
struct Health {
|
||||
status: &'static str,
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
Json(Health {
|
||||
status: "ok",
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
})
|
||||
}
|
||||
|
||||
/// Process text/image OCR request
|
||||
/// Supports multipart/form-data, base64, and URL inputs
|
||||
///
|
||||
/// # Important
|
||||
/// This endpoint requires OCR models to be configured. If models are not available,
|
||||
/// returns a 503 Service Unavailable error with instructions.
|
||||
pub async fn process_text(
|
||||
State(_state): State<AppState>,
|
||||
Json(request): Json<TextRequest>,
|
||||
) -> Result<Json<TextResponse>, ErrorResponse> {
|
||||
info!("Processing text OCR request");
|
||||
|
||||
// Validate request
|
||||
request.validate().map_err(|e| {
|
||||
warn!("Invalid request: {:?}", e);
|
||||
ErrorResponse::validation_error(format!("Validation failed: {}", e))
|
||||
})?;
|
||||
|
||||
// Download or decode image
|
||||
let image_data = match request.get_image_data().await {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
error!("Failed to get image data: {:?}", e);
|
||||
return Err(ErrorResponse::internal_error("Failed to process image"));
|
||||
}
|
||||
};
|
||||
|
||||
// Validate image data is not empty
|
||||
if image_data.is_empty() {
|
||||
return Err(ErrorResponse::validation_error("Image data is empty"));
|
||||
}
|
||||
|
||||
// OCR processing requires models to be configured
|
||||
// Return informative error explaining how to set up the service
|
||||
Err(ErrorResponse::service_unavailable(
|
||||
"OCR service not fully configured. ONNX models are required for OCR processing. \
|
||||
Please download compatible models (PaddleOCR, TrOCR) and configure the model directory. \
|
||||
See documentation at /docs/MODEL_SETUP.md for setup instructions.",
|
||||
))
|
||||
}
|
||||
|
||||
/// Process digital ink strokes
|
||||
///
|
||||
/// # Important
|
||||
/// This endpoint requires OCR models to be configured.
|
||||
pub async fn process_strokes(
|
||||
State(_state): State<AppState>,
|
||||
Json(request): Json<StrokesRequest>,
|
||||
) -> Result<Json<TextResponse>, ErrorResponse> {
|
||||
info!(
|
||||
"Processing strokes request with {} strokes",
|
||||
request.strokes.len()
|
||||
);
|
||||
|
||||
request
|
||||
.validate()
|
||||
.map_err(|e| ErrorResponse::validation_error(format!("Validation failed: {}", e)))?;
|
||||
|
||||
// Validate we have stroke data
|
||||
if request.strokes.is_empty() {
|
||||
return Err(ErrorResponse::validation_error("No strokes provided"));
|
||||
}
|
||||
|
||||
// Stroke recognition requires models to be configured
|
||||
Err(ErrorResponse::service_unavailable(
|
||||
"Stroke recognition service not configured. ONNX models required for ink recognition.",
|
||||
))
|
||||
}
|
||||
|
||||
/// Process legacy LaTeX equation request
|
||||
///
|
||||
/// # Important
|
||||
/// This endpoint requires OCR models to be configured.
|
||||
pub async fn process_latex(
|
||||
State(_state): State<AppState>,
|
||||
Json(request): Json<LatexRequest>,
|
||||
) -> Result<Json<TextResponse>, ErrorResponse> {
|
||||
info!("Processing legacy LaTeX request");
|
||||
|
||||
request
|
||||
.validate()
|
||||
.map_err(|e| ErrorResponse::validation_error(format!("Validation failed: {}", e)))?;
|
||||
|
||||
// LaTeX recognition requires models to be configured
|
||||
Err(ErrorResponse::service_unavailable(
|
||||
"LaTeX recognition service not configured. ONNX models required.",
|
||||
))
|
||||
}
|
||||
|
||||
/// Create async PDF processing job
|
||||
pub async fn process_pdf(
|
||||
State(state): State<AppState>,
|
||||
Json(request): Json<PdfRequest>,
|
||||
) -> Result<Json<PdfResponse>, ErrorResponse> {
|
||||
info!("Creating PDF processing job");
|
||||
|
||||
request
|
||||
.validate()
|
||||
.map_err(|e| ErrorResponse::validation_error(format!("Validation failed: {}", e)))?;
|
||||
|
||||
// Create job
|
||||
let job = PdfJob::new(request);
|
||||
let job_id = job.id.clone();
|
||||
|
||||
// Queue job
|
||||
state.job_queue.enqueue(job).await.map_err(|e| {
|
||||
error!("Failed to enqueue job: {:?}", e);
|
||||
ErrorResponse::internal_error("Failed to create PDF job")
|
||||
})?;
|
||||
|
||||
let response = PdfResponse {
|
||||
pdf_id: job_id,
|
||||
status: JobStatus::Processing,
|
||||
message: Some("PDF processing started".to_string()),
|
||||
result: None,
|
||||
error: None,
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
/// Get PDF job status
|
||||
pub async fn get_pdf_status(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<Json<PdfResponse>, ErrorResponse> {
|
||||
info!("Getting PDF job status: {}", id);
|
||||
|
||||
let status = state
|
||||
.job_queue
|
||||
.get_status(&id)
|
||||
.await
|
||||
.ok_or_else(|| ErrorResponse::not_found("Job not found"))?;
|
||||
|
||||
let response = PdfResponse {
|
||||
pdf_id: id.clone(),
|
||||
status: status.clone(),
|
||||
message: Some(format!("Job status: {:?}", status)),
|
||||
result: state.job_queue.get_result(&id).await,
|
||||
error: state.job_queue.get_error(&id).await,
|
||||
};
|
||||
|
||||
Ok(Json(response))
|
||||
}
|
||||
|
||||
/// Delete PDF job
|
||||
pub async fn delete_pdf_job(
|
||||
State(state): State<AppState>,
|
||||
Path(id): Path<String>,
|
||||
) -> Result<StatusCode, ErrorResponse> {
|
||||
info!("Deleting PDF job: {}", id);
|
||||
|
||||
state
|
||||
.job_queue
|
||||
.cancel(&id)
|
||||
.await
|
||||
.map_err(|_| ErrorResponse::not_found("Job not found"))?;
|
||||
|
||||
Ok(StatusCode::NO_CONTENT)
|
||||
}
|
||||
|
||||
/// Stream PDF processing results via SSE
|
||||
pub async fn stream_pdf_results(
|
||||
State(_state): State<AppState>,
|
||||
Path(_id): Path<String>,
|
||||
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
|
||||
info!("Streaming PDF results for job: {}", _id);
|
||||
|
||||
let stream = stream::unfold(0, move |page| {
|
||||
async move {
|
||||
if page > 10 {
|
||||
// Example: stop after 10 pages
|
||||
return None;
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
let event = Event::default()
|
||||
.json_data(serde_json::json!({
|
||||
"page": page,
|
||||
"text": format!("Content from page {}", page),
|
||||
"progress": (page as f32 / 10.0) * 100.0
|
||||
}))
|
||||
.ok()?;
|
||||
|
||||
Some((Ok(event), page + 1))
|
||||
}
|
||||
});
|
||||
|
||||
Sse::new(stream)
|
||||
}
|
||||
|
||||
/// Convert document to different format (MMD/DOCX/etc).
///
/// # Note
/// Always answers 501 Not Implemented: conversion requires additional
/// backend services that are not configured. The request body is
/// accepted as arbitrary JSON and currently ignored.
pub async fn convert_document(
    State(_state): State<AppState>,
    Json(_request): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>, ErrorResponse> {
    info!("Converting document");

    // Document conversion is not yet implemented
    Err(ErrorResponse::not_implemented(
        "Document conversion is not yet implemented. This feature requires additional backend services."
    ))
}
|
||||
|
||||
/// Pagination parameters for the OCR history endpoint.
#[derive(Deserialize)]
pub struct HistoryQuery {
    // Zero-based page number; 0 when omitted.
    #[serde(default)]
    page: u32,
    // Page size; 50 when omitted (see `default_limit`).
    #[serde(default = "default_limit")]
    limit: u32,
}
|
||||
|
||||
/// Default page size for history queries, used when `limit` is omitted.
fn default_limit() -> u32 {
    50
}
|
||||
|
||||
/// Get OCR processing history
|
||||
///
|
||||
/// # Note
|
||||
/// History storage requires a database backend to be configured.
|
||||
/// Returns empty results if no database is available.
|
||||
pub async fn get_ocr_results(
|
||||
State(_state): State<AppState>,
|
||||
Query(params): Query<HistoryQuery>,
|
||||
) -> Result<Json<serde_json::Value>, ErrorResponse> {
|
||||
info!(
|
||||
"Getting OCR results history: page={}, limit={}",
|
||||
params.page, params.limit
|
||||
);
|
||||
|
||||
// History storage not configured - return empty results with notice
|
||||
Ok(Json(serde_json::json!({
|
||||
"results": [],
|
||||
"total": 0,
|
||||
"page": params.page,
|
||||
"limit": params.limit,
|
||||
"notice": "History storage not configured. Results are not persisted."
|
||||
})))
|
||||
}
|
||||
|
||||
/// Get OCR usage statistics
|
||||
///
|
||||
/// # Note
|
||||
/// Usage tracking requires a database backend to be configured.
|
||||
/// Returns zeros if no database is available.
|
||||
pub async fn get_ocr_usage(
|
||||
State(_state): State<AppState>,
|
||||
) -> Result<Json<serde_json::Value>, ErrorResponse> {
|
||||
info!("Getting OCR usage statistics");
|
||||
|
||||
// Usage tracking not configured - return zeros with notice
|
||||
Ok(Json(serde_json::json!({
|
||||
"requests_today": 0,
|
||||
"requests_month": 0,
|
||||
"quota_limit": null,
|
||||
"quota_remaining": null,
|
||||
"notice": "Usage tracking not configured. Statistics are not recorded."
|
||||
})))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The health endpoint must always answer 200 OK.
    #[tokio::test]
    async fn test_health_check() {
        let response = get_health().await.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }
}
|
||||
281
examples/scipix/src/api/jobs.rs
Normal file
281
examples/scipix/src/api/jobs.rs
Normal file
@@ -0,0 +1,281 @@
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::requests::PdfRequest;
|
||||
|
||||
/// Job status enumeration.
///
/// Serialized in lowercase (e.g. "queued") in JSON responses and
/// webhook payloads.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
pub enum JobStatus {
    /// Job is queued but not started
    Queued,
    /// Job is currently processing
    Processing,
    /// Job completed successfully
    Completed,
    /// Job failed with error
    Failed,
    /// Job was cancelled
    Cancelled,
}
|
||||
|
||||
/// PDF processing job.
///
/// A snapshot of one unit of work in the [`JobQueue`]: the original
/// request plus lifecycle bookkeeping. The copy stored in the queue's
/// map is the source of truth; copies sent over the channel are clones.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfJob {
    /// Unique job identifier (a UUID v4 string)
    pub id: String,

    /// Original request
    pub request: PdfRequest,

    /// Current status
    pub status: JobStatus,

    /// Creation timestamp
    pub created_at: DateTime<Utc>,

    /// Last update timestamp (refreshed on every status change)
    pub updated_at: DateTime<Utc>,

    /// Processing result (set only when status is `Completed`)
    pub result: Option<String>,

    /// Error message (set only when status is `Failed`)
    pub error: Option<String>,
}
|
||||
|
||||
impl PdfJob {
|
||||
/// Create a new PDF job
|
||||
pub fn new(request: PdfRequest) -> Self {
|
||||
let now = Utc::now();
|
||||
Self {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
request,
|
||||
status: JobStatus::Queued,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
result: None,
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update job status
|
||||
pub fn update_status(&mut self, status: JobStatus) {
|
||||
self.status = status;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Set job result
|
||||
pub fn set_result(&mut self, result: String) {
|
||||
self.result = Some(result);
|
||||
self.status = JobStatus::Completed;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
|
||||
/// Set job error
|
||||
pub fn set_error(&mut self, error: String) {
|
||||
self.error = Some(error);
|
||||
self.status = JobStatus::Failed;
|
||||
self.updated_at = Utc::now();
|
||||
}
|
||||
}
|
||||
|
||||
/// Async job queue with webhook support.
///
/// Jobs live in a shared map (keyed by job id) and are handed to a
/// background worker task through a bounded channel.
pub struct JobQueue {
    /// Job storage, shared with the worker task
    jobs: Arc<RwLock<HashMap<String, PdfJob>>>,

    /// Job submission channel (bounded; `enqueue` awaits when full)
    tx: mpsc::Sender<PdfJob>,

    /// Worker task handle, held so the task is tied to the queue's lifetime
    _handle: Option<tokio::task::JoinHandle<()>>,
}
|
||||
|
||||
impl JobQueue {
|
||||
/// Create a new job queue
|
||||
pub fn new() -> Self {
|
||||
Self::with_capacity(1000)
|
||||
}
|
||||
|
||||
/// Create a job queue with specific capacity
|
||||
pub fn with_capacity(capacity: usize) -> Self {
|
||||
let jobs = Arc::new(RwLock::new(HashMap::new()));
|
||||
let (tx, rx) = mpsc::channel(capacity);
|
||||
|
||||
let queue_jobs = jobs.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
Self::process_jobs(queue_jobs, rx).await;
|
||||
});
|
||||
|
||||
Self {
|
||||
jobs,
|
||||
tx,
|
||||
_handle: Some(handle),
|
||||
}
|
||||
}
|
||||
|
||||
/// Enqueue a new job
|
||||
pub async fn enqueue(&self, mut job: PdfJob) -> anyhow::Result<()> {
|
||||
job.update_status(JobStatus::Queued);
|
||||
|
||||
// Store job
|
||||
{
|
||||
let mut jobs = self.jobs.write().await;
|
||||
jobs.insert(job.id.clone(), job.clone());
|
||||
}
|
||||
|
||||
// Send to processing queue
|
||||
self.tx.send(job).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get job status
|
||||
pub async fn get_status(&self, id: &str) -> Option<JobStatus> {
|
||||
let jobs = self.jobs.read().await;
|
||||
jobs.get(id).map(|job| job.status.clone())
|
||||
}
|
||||
|
||||
/// Get job result
|
||||
pub async fn get_result(&self, id: &str) -> Option<String> {
|
||||
let jobs = self.jobs.read().await;
|
||||
jobs.get(id).and_then(|job| job.result.clone())
|
||||
}
|
||||
|
||||
/// Get job error
|
||||
pub async fn get_error(&self, id: &str) -> Option<String> {
|
||||
let jobs = self.jobs.read().await;
|
||||
jobs.get(id).and_then(|job| job.error.clone())
|
||||
}
|
||||
|
||||
/// Cancel a job
|
||||
pub async fn cancel(&self, id: &str) -> anyhow::Result<()> {
|
||||
let mut jobs = self.jobs.write().await;
|
||||
if let Some(job) = jobs.get_mut(id) {
|
||||
job.update_status(JobStatus::Cancelled);
|
||||
Ok(())
|
||||
} else {
|
||||
anyhow::bail!("Job not found")
|
||||
}
|
||||
}
|
||||
|
||||
/// Background job processor
|
||||
async fn process_jobs(
|
||||
jobs: Arc<RwLock<HashMap<String, PdfJob>>>,
|
||||
mut rx: mpsc::Receiver<PdfJob>,
|
||||
) {
|
||||
while let Some(job) = rx.recv().await {
|
||||
let job_id = job.id.clone();
|
||||
|
||||
// Update status to processing
|
||||
{
|
||||
let mut jobs_lock = jobs.write().await;
|
||||
if let Some(stored_job) = jobs_lock.get_mut(&job_id) {
|
||||
stored_job.update_status(JobStatus::Processing);
|
||||
}
|
||||
}
|
||||
|
||||
// Simulate PDF processing
|
||||
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
|
||||
|
||||
// Update with result
|
||||
{
|
||||
let mut jobs_lock = jobs.write().await;
|
||||
if let Some(stored_job) = jobs_lock.get_mut(&job_id) {
|
||||
stored_job.set_result("Processed PDF content".to_string());
|
||||
|
||||
// Send webhook if specified
|
||||
if let Some(webhook_url) = &stored_job.request.webhook_url {
|
||||
Self::send_webhook(webhook_url, stored_job).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Send webhook notification
|
||||
async fn send_webhook(url: &str, job: &PdfJob) {
|
||||
let client = reqwest::Client::new();
|
||||
let payload = serde_json::json!({
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"result": job.result,
|
||||
"error": job.error,
|
||||
});
|
||||
|
||||
if let Err(e) = client.post(url).json(&payload).send().await {
|
||||
tracing::error!("Failed to send webhook to {}: {:?}", url, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for JobQueue {
    /// Equivalent to [`JobQueue::new`] (capacity 1000; requires a
    /// running Tokio runtime).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::requests::{PdfOptions, RequestMetadata};

    /// A fresh job starts `Queued` with no result or error.
    #[tokio::test]
    async fn test_job_creation() {
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };

        let job = PdfJob::new(request);
        assert_eq!(job.status, JobStatus::Queued);
        assert!(job.result.is_none());
        assert!(job.error.is_none());
    }

    /// An enqueued job is immediately visible via `get_status`.
    #[tokio::test]
    async fn test_job_queue_enqueue() {
        let queue = JobQueue::new();
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };

        let job = PdfJob::new(request);
        let job_id = job.id.clone();

        queue.enqueue(job).await.unwrap();

        let status = queue.get_status(&job_id).await;
        assert!(status.is_some());
    }

    /// Cancelling a queued job flips its status to `Cancelled`.
    #[tokio::test]
    async fn test_job_cancellation() {
        let queue = JobQueue::new();
        let request = PdfRequest {
            url: "https://example.com/test.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };

        let job = PdfJob::new(request);
        let job_id = job.id.clone();

        queue.enqueue(job).await.unwrap();
        queue.cancel(&job_id).await.unwrap();

        let status = queue.get_status(&job_id).await;
        assert_eq!(status, Some(JobStatus::Cancelled));
    }
}
|
||||
197
examples/scipix/src/api/middleware.rs
Normal file
197
examples/scipix/src/api/middleware.rs
Normal file
@@ -0,0 +1,197 @@
|
||||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderMap,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use governor::{
|
||||
clock::DefaultClock,
|
||||
state::{InMemoryState, NotKeyed},
|
||||
Quota, RateLimiter,
|
||||
};
|
||||
use nonzero_ext::nonzero;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use super::{responses::ErrorResponse, state::AppState};
|
||||
|
||||
/// Authentication middleware.
/// Validates app_id and app_key from headers or query parameters.
///
/// Behavior:
/// - If `state.auth_enabled` is false, every request passes through.
/// - Otherwise both `app_id` and `app_key` must be present (header
///   first, query string as fallback) and match a configured key, or
///   the request is rejected with 401.
///
/// NOTE(review): accepting `app_key` via the query string means secrets
/// can end up in access logs, proxies, and browser history — confirm
/// this fallback is really required.
pub async fn auth_middleware(
    State(state): State<AppState>,
    headers: HeaderMap,
    request: Request,
    next: Next,
) -> Result<Response, ErrorResponse> {
    // Check if authentication is enabled
    if !state.auth_enabled {
        debug!("Authentication disabled, allowing request");
        return Ok(next.run(request).await);
    }

    // Extract credentials from headers (query string as fallback).
    let app_id = headers
        .get("app_id")
        .and_then(|v| v.to_str().ok())
        .or_else(|| {
            // Fallback to query parameters
            request
                .uri()
                .query()
                .and_then(|q| extract_query_param(q, "app_id"))
        });

    let app_key = headers
        .get("app_key")
        .and_then(|v| v.to_str().ok())
        .or_else(|| {
            request
                .uri()
                .query()
                .and_then(|q| extract_query_param(q, "app_key"))
        });

    // Validate credentials: both pieces must be present and match.
    match (app_id, app_key) {
        (Some(id), Some(key)) => {
            if validate_credentials(&state, id, key).await {
                debug!("Authentication successful for app_id: {}", id);
                Ok(next.run(request).await)
            } else {
                warn!("Invalid credentials for app_id: {}", id);
                Err(ErrorResponse::unauthorized("Invalid credentials"))
            }
        }
        _ => {
            warn!("Missing authentication credentials");
            Err(ErrorResponse::unauthorized("Missing app_id or app_key"))
        }
    }
}
|
||||
|
||||
/// Rate limiting middleware using token bucket algorithm
|
||||
pub async fn rate_limit_middleware(
|
||||
State(state): State<AppState>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Result<Response, ErrorResponse> {
|
||||
// Check rate limit
|
||||
match state.rate_limiter.check() {
|
||||
Ok(_) => {
|
||||
debug!("Rate limit check passed");
|
||||
Ok(next.run(request).await)
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Rate limit exceeded");
|
||||
Err(ErrorResponse::rate_limited(
|
||||
"Rate limit exceeded. Please try again later.",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate app credentials using secure comparison
|
||||
///
|
||||
/// SECURITY: This implementation:
|
||||
/// 1. Requires credentials to be pre-configured in AppState
|
||||
/// 2. Uses constant-time comparison to prevent timing attacks
|
||||
/// 3. Hashes the key before comparison
|
||||
async fn validate_credentials(state: &AppState, app_id: &str, app_key: &str) -> bool {
|
||||
// Reject empty credentials
|
||||
if app_id.is_empty() || app_key.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Get configured credentials from state
|
||||
let Some(expected_key_hash) = state.api_keys.get(app_id) else {
|
||||
warn!("Unknown app_id attempted authentication: {}", app_id);
|
||||
return false;
|
||||
};
|
||||
|
||||
// Hash the provided key
|
||||
let provided_key_hash = hash_api_key(app_key);
|
||||
|
||||
// Constant-time comparison to prevent timing attacks
|
||||
constant_time_compare(&provided_key_hash, expected_key_hash.as_str())
|
||||
}
|
||||
|
||||
/// Hash an API key using SHA-256
|
||||
fn hash_api_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Constant-time string comparison to prevent timing attacks.
///
/// A length mismatch returns early (length is not secret here — both
/// sides are fixed-size hex digests); for equal lengths every byte pair
/// is XOR-folded so the running time is independent of where the
/// strings differ.
fn constant_time_compare(a: &str, b: &str) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let diff = a
        .bytes()
        .zip(b.bytes())
        .fold(0u8, |acc, (x, y)| acc | (x ^ y));
    diff == 0
}
|
||||
|
||||
/// Extract a query parameter's value from a raw query string.
///
/// Scans `&`-separated pairs; a pair matches when its `=`-separated key
/// equals `param` and a value segment is present. For `k=v1=v2` only
/// `v1` is returned (matching `split('=')` semantics). No URL-decoding
/// is performed.
fn extract_query_param<'a>(query: &'a str, param: &str) -> Option<&'a str> {
    for pair in query.split('&') {
        let mut kv = pair.split('=');
        if let (Some(key), Some(value)) = (kv.next(), kv.next()) {
            if key == param {
                return Some(value);
            }
        }
    }
    None
}
|
||||
|
||||
/// Create a rate limiter with token bucket algorithm.
///
/// The limiter is unkeyed: the 100-requests-per-minute quota is shared
/// across all clients of the process, not tracked per client.
pub fn create_rate_limiter() -> Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>> {
    // Allow 100 requests per minute
    let quota = Quota::per_minute(nonzero!(100u32));
    Arc::new(RateLimiter::direct(quota))
}
|
||||
|
||||
/// Shared handle to the process-wide, unkeyed token-bucket rate limiter
/// (see `create_rate_limiter`).
pub type AppRateLimiter = Arc<RateLimiter<NotKeyed, InMemoryState, DefaultClock>>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Query parsing finds present keys and rejects missing ones.
    #[test]
    fn test_extract_query_param() {
        let query = "app_id=123&app_key=secret&foo=bar";
        assert_eq!(extract_query_param(query, "app_id"), Some("123"));
        assert_eq!(extract_query_param(query, "app_key"), Some("secret"));
        assert_eq!(extract_query_param(query, "foo"), Some("bar"));
        assert_eq!(extract_query_param(query, "missing"), None);
    }

    /// Hashing is deterministic and input-sensitive.
    #[test]
    fn test_hash_api_key() {
        let key = "test_key_123";
        let hash1 = hash_api_key(key);
        let hash2 = hash_api_key(key);
        assert_eq!(hash1, hash2);
        assert_ne!(hash_api_key("different"), hash1);
    }

    /// Comparison handles equality, inequality, and length mismatches.
    #[test]
    fn test_constant_time_compare() {
        assert!(constant_time_compare("abc", "abc"));
        assert!(!constant_time_compare("abc", "abd"));
        assert!(!constant_time_compare("abc", "ab"));
        assert!(!constant_time_compare("", "a"));
    }

    /// Empty id or key must never authenticate.
    #[tokio::test]
    async fn test_validate_credentials_rejects_empty() {
        let state = AppState::new();
        assert!(!validate_credentials(&state, "", "key").await);
        assert!(!validate_credentials(&state, "test", "").await);
        assert!(!validate_credentials(&state, "", "").await);
    }
}
|
||||
91
examples/scipix/src/api/mod.rs
Normal file
91
examples/scipix/src/api/mod.rs
Normal file
@@ -0,0 +1,91 @@
|
||||
pub mod handlers;
|
||||
pub mod jobs;
|
||||
pub mod middleware;
|
||||
pub mod requests;
|
||||
pub mod responses;
|
||||
pub mod routes;
|
||||
pub mod state;
|
||||
|
||||
use anyhow::Result;
|
||||
use axum::Router;
|
||||
use std::net::SocketAddr;
|
||||
use tokio::signal;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use self::state::AppState;
|
||||
|
||||
/// Main API server structure.
///
/// Owns the shared application state and the socket address to bind;
/// consumed by [`ApiServer::start`].
pub struct ApiServer {
    // Shared application state, cloned into the router.
    state: AppState,
    // Listen address (host:port).
    addr: SocketAddr,
}
|
||||
|
||||
impl ApiServer {
|
||||
/// Create a new API server instance
|
||||
pub fn new(state: AppState, addr: SocketAddr) -> Self {
|
||||
Self { state, addr }
|
||||
}
|
||||
|
||||
/// Start the API server with graceful shutdown
|
||||
pub async fn start(self) -> Result<()> {
|
||||
let app = self.create_router();
|
||||
|
||||
info!("Starting Scipix API server on {}", self.addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(self.addr).await?;
|
||||
|
||||
axum::serve(listener, app)
|
||||
.with_graceful_shutdown(shutdown_signal())
|
||||
.await?;
|
||||
|
||||
info!("Server shutdown complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create the application router with all routes and middleware
|
||||
fn create_router(&self) -> Router {
|
||||
routes::router(self.state.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Graceful shutdown signal handler.
///
/// Resolves when either Ctrl+C or (on Unix) SIGTERM is received; axum
/// uses this future to trigger graceful shutdown.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    // SIGTERM is the conventional stop signal from init systems / k8s.
    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    // No SIGTERM on non-Unix targets; this branch never resolves there.
    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    tokio::select! {
        _ = ctrl_c => {
            warn!("Received Ctrl+C, shutting down...");
        },
        _ = terminate => {
            warn!("Received terminate signal, shutting down...");
        },
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction stores the requested bind address unchanged.
    #[tokio::test]
    async fn test_server_creation() {
        let state = AppState::new();
        let addr = "127.0.0.1:3000".parse().unwrap();
        let server = ApiServer::new(state, addr);
        assert_eq!(server.addr, addr);
    }
}
|
||||
227
examples/scipix/src/api/requests.rs
Normal file
227
examples/scipix/src/api/requests.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use validator::Validate;
|
||||
|
||||
/// Text/Image OCR request.
///
/// One of `base64` or `url` is expected; see `get_image_data` for the
/// precedence rules.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct TextRequest {
    /// Image source (base64, URL, or multipart)
    /// NOTE(review): this field is not read by `get_image_data` — confirm
    /// whether `src` handling is implemented elsewhere or missing.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub src: Option<String>,

    /// Base64 encoded image data
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base64: Option<String>,

    /// Image URL (must be a well-formed URL)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub url: Option<String>,

    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
|
||||
|
||||
impl TextRequest {
|
||||
/// Get image data from request
|
||||
pub async fn get_image_data(&self) -> anyhow::Result<Vec<u8>> {
|
||||
if let Some(base64_data) = &self.base64 {
|
||||
// Decode base64
|
||||
use base64::Engine;
|
||||
let decoded = base64::engine::general_purpose::STANDARD.decode(base64_data)?;
|
||||
Ok(decoded)
|
||||
} else if let Some(url) = &self.url {
|
||||
// Download from URL
|
||||
let response = reqwest::get(url).await?;
|
||||
let bytes = response.bytes().await?;
|
||||
Ok(bytes.to_vec())
|
||||
} else {
|
||||
anyhow::bail!("No image data provided")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Digital ink strokes request.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct StrokesRequest {
    /// Array of stroke data; validation rejects an empty array
    #[validate(length(min = 1))]
    pub strokes: Vec<Stroke>,

    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
|
||||
|
||||
/// Single stroke in digital ink.
///
/// NOTE(review): `x` and `y` are presumably parallel arrays of equal
/// length, but nothing visible here validates that — confirm upstream
/// guarantees before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Stroke {
    /// X coordinates
    pub x: Vec<f64>,

    /// Y coordinates
    pub y: Vec<f64>,

    /// Optional timestamps
    #[serde(skip_serializing_if = "Option::is_none")]
    pub t: Option<Vec<f64>>,
}
|
||||
|
||||
/// Legacy LaTeX equation request.
///
/// Mirrors [`TextRequest`]'s image-source fields for the older
/// equation-only endpoint.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct LatexRequest {
    /// Image source
    #[serde(skip_serializing_if = "Option::is_none")]
    pub src: Option<String>,

    /// Base64 encoded image
    #[serde(skip_serializing_if = "Option::is_none")]
    pub base64: Option<String>,

    /// Image URL (must be a well-formed URL)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub url: Option<String>,

    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
|
||||
|
||||
/// PDF processing request.
#[derive(Debug, Clone, Serialize, Deserialize, Validate)]
pub struct PdfRequest {
    /// PDF file URL (must be a well-formed URL)
    #[validate(url)]
    pub url: String,

    /// Conversion options
    #[serde(default)]
    pub options: PdfOptions,

    /// Webhook URL for completion notification (validated when present)
    #[serde(skip_serializing_if = "Option::is_none")]
    #[validate(url)]
    pub webhook_url: Option<String>,

    /// Request metadata
    #[serde(default)]
    pub metadata: RequestMetadata,
}
|
||||
|
||||
/// PDF processing options
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct PdfOptions {
|
||||
/// Output format
|
||||
#[serde(default = "default_format")]
|
||||
pub format: String,
|
||||
|
||||
/// Enable OCR
|
||||
#[serde(default)]
|
||||
pub enable_ocr: bool,
|
||||
|
||||
/// Include images
|
||||
#[serde(default = "default_true")]
|
||||
pub include_images: bool,
|
||||
|
||||
/// Page range (e.g., "1-5")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub page_range: Option<String>,
|
||||
}
|
||||
|
||||
/// Default output format for PDF conversion ("mmd").
fn default_format() -> String {
    String::from("mmd")
}
|
||||
|
||||
/// Serde default helper for boolean fields that default to `true`.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Request metadata
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct RequestMetadata {
|
||||
/// Output formats
|
||||
#[serde(default = "default_formats")]
|
||||
pub formats: Vec<String>,
|
||||
|
||||
/// Include confidence scores
|
||||
#[serde(default)]
|
||||
pub include_confidence: bool,
|
||||
|
||||
/// Enable math mode
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_math: bool,
|
||||
|
||||
/// Language hint
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub language: Option<String>,
|
||||
}
|
||||
|
||||
/// Serde default helper: the default output format list, just plain text.
fn default_formats() -> Vec<String> {
    vec![String::from("text")]
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): `TextRequest`, `StrokesRequest`, and `Stroke` are declared
    // earlier in this file; these tests only exercise the `validator` rules.

    /// A request carrying only a base64 payload passes validation.
    #[test]
    fn test_text_request_validation() {
        let request = TextRequest {
            src: None,
            base64: Some("SGVsbG8gV29ybGQ=".to_string()),
            url: None,
            metadata: RequestMetadata::default(),
        };

        assert!(request.validate().is_ok());
    }

    /// A single non-empty stroke is sufficient for a valid request.
    #[test]
    fn test_strokes_request_validation() {
        let request = StrokesRequest {
            strokes: vec![Stroke {
                x: vec![0.0, 1.0, 2.0],
                y: vec![0.0, 1.0, 0.0],
                t: None,
            }],
            metadata: RequestMetadata::default(),
        };

        assert!(request.validate().is_ok());
    }

    /// An empty stroke list must be rejected.
    #[test]
    fn test_empty_strokes_validation() {
        let request = StrokesRequest {
            strokes: vec![],
            metadata: RequestMetadata::default(),
        };

        assert!(request.validate().is_err());
    }

    /// A syntactically valid URL passes the `#[validate(url)]` rule.
    #[test]
    fn test_pdf_request_validation() {
        let request = PdfRequest {
            url: "https://example.com/document.pdf".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };

        assert!(request.validate().is_ok());
    }

    /// A malformed URL is rejected by the `#[validate(url)]` rule.
    #[test]
    fn test_invalid_url() {
        let request = PdfRequest {
            url: "not-a-url".to_string(),
            options: PdfOptions::default(),
            webhook_url: None,
            metadata: RequestMetadata::default(),
        };

        assert!(request.validate().is_err());
    }
}
|
||||
177
examples/scipix/src/api/responses.rs
Normal file
177
examples/scipix/src/api/responses.rs
Normal file
@@ -0,0 +1,177 @@
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::jobs::JobStatus;
|
||||
|
||||
/// Standard text/OCR response.
///
/// Optional output fields are omitted from the JSON body when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextResponse {
    /// Unique request identifier.
    pub request_id: String,

    /// Recognized text.
    pub text: String,

    /// Confidence score (0.0 - 1.0).
    pub confidence: f64,

    /// LaTeX output (only present if requested).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,

    /// MathML output (only present if requested).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mathml: Option<String>,

    /// HTML output (only present if requested).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub html: Option<String>,
}
|
||||
|
||||
/// PDF processing response.
///
/// Returned both when a job is created and when its status is polled;
/// `result` and `error` are mutually relevant to the terminal states.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfResponse {
    /// PDF job identifier.
    pub pdf_id: String,

    /// Current job status.
    pub status: JobStatus,

    /// Optional human-readable status message.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub message: Option<String>,

    /// Processing result (populated when the job completes).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub result: Option<String>,

    /// Error details (populated when the job fails).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
|
||||
|
||||
/// Error response returned by all API handlers.
///
/// The HTTP status is carried alongside the JSON body but never serialized;
/// it is applied by the `IntoResponse` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Machine-readable error code (e.g. "VALIDATION_ERROR").
    pub error_code: String,

    /// Human-readable error message.
    pub message: String,

    /// HTTP status code; excluded from (de)serialization.
    // NOTE(review): `#[serde(skip)]` means deserialization fills this with
    // `StatusCode::default()` (200 OK) — confirm that is acceptable for
    // any client-side use of this type.
    #[serde(skip)]
    pub status: StatusCode,
}
|
||||
|
||||
impl ErrorResponse {
|
||||
/// Create a validation error response
|
||||
pub fn validation_error(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "VALIDATION_ERROR".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::BAD_REQUEST,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an unauthorized error response
|
||||
pub fn unauthorized(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "UNAUTHORIZED".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::UNAUTHORIZED,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not found error response
|
||||
pub fn not_found(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "NOT_FOUND".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::NOT_FOUND,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a rate limit error response
|
||||
pub fn rate_limited(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "RATE_LIMIT_EXCEEDED".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::TOO_MANY_REQUESTS,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an internal error response
|
||||
pub fn internal_error(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "INTERNAL_ERROR".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a service unavailable error response
|
||||
/// Used when the service is not fully configured (e.g., missing models)
|
||||
pub fn service_unavailable(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "SERVICE_UNAVAILABLE".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::SERVICE_UNAVAILABLE,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a not implemented error response
|
||||
pub fn not_implemented(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
error_code: "NOT_IMPLEMENTED".to_string(),
|
||||
message: message.into(),
|
||||
status: StatusCode::NOT_IMPLEMENTED,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for ErrorResponse {
|
||||
fn into_response(self) -> Response {
|
||||
let status = self.status;
|
||||
(status, Json(self)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// `None` optional fields must be omitted from the serialized JSON.
    #[test]
    fn test_text_response_serialization() {
        let response = TextResponse {
            request_id: "test-123".to_string(),
            text: "Hello World".to_string(),
            confidence: 0.95,
            latex: Some("x^2".to_string()),
            mathml: None,
            html: None,
        };

        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("request_id"));
        assert!(json.contains("test-123"));
        // `mathml` is None, so `skip_serializing_if` should drop the key.
        assert!(!json.contains("mathml"));
    }

    /// Each constructor pairs the expected code and HTTP status.
    #[test]
    fn test_error_response_creation() {
        let error = ErrorResponse::validation_error("Invalid input");
        assert_eq!(error.status, StatusCode::BAD_REQUEST);
        assert_eq!(error.error_code, "VALIDATION_ERROR");

        let error = ErrorResponse::unauthorized("Invalid credentials");
        assert_eq!(error.status, StatusCode::UNAUTHORIZED);

        let error = ErrorResponse::rate_limited("Too many requests");
        assert_eq!(error.status, StatusCode::TOO_MANY_REQUESTS);
    }
}
|
||||
103
examples/scipix/src/api/routes.rs
Normal file
103
examples/scipix/src/api/routes.rs
Normal file
@@ -0,0 +1,103 @@
|
||||
use axum::{
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
compression::CompressionLayer,
|
||||
cors::CorsLayer,
|
||||
trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer},
|
||||
};
|
||||
use tracing::Level;
|
||||
|
||||
use super::{
|
||||
handlers::{
|
||||
convert_document, delete_pdf_job, get_health, get_ocr_results, get_ocr_usage,
|
||||
get_pdf_status, process_latex, process_pdf, process_strokes, process_text,
|
||||
stream_pdf_results,
|
||||
},
|
||||
middleware::{auth_middleware, rate_limit_middleware},
|
||||
state::AppState,
|
||||
};
|
||||
|
||||
/// Create the main application router with all routes and middleware.
///
/// All `/v3/*` routes go through auth and rate limiting; `/health` is
/// mounted separately without them. Tracing, CORS, and compression wrap
/// the combined router so they apply to every endpoint.
pub fn router(state: AppState) -> Router {
    // API v3 routes
    let api_routes = Router::new()
        // Image processing
        .route("/v3/text", post(process_text))
        // Digital ink processing
        .route("/v3/strokes", post(process_strokes))
        // Legacy equation processing
        .route("/v3/latex", post(process_latex))
        // Async PDF processing (submit / poll / cancel / stream)
        .route("/v3/pdf", post(process_pdf))
        .route("/v3/pdf/:id", get(get_pdf_status))
        .route("/v3/pdf/:id", delete(delete_pdf_job))
        .route("/v3/pdf/:id/stream", get(stream_pdf_results))
        // Document conversion
        .route("/v3/converter", post(convert_document))
        // History and usage
        .route("/v3/ocr-results", get(get_ocr_results))
        .route("/v3/ocr-usage", get(get_ocr_usage))
        // Apply auth and rate limiting to all API routes above.
        // NOTE: layers added here do NOT apply to routes merged in later
        // (i.e. /health stays unauthenticated by construction).
        .layer(
            ServiceBuilder::new()
                .layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    auth_middleware,
                ))
                .layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    rate_limit_middleware,
                )),
        );

    // Health check (no auth required)
    let health_routes = Router::new().route("/health", get(get_health));

    // Combine all routes and attach the cross-cutting layers.
    Router::new()
        .merge(api_routes)
        .merge(health_routes)
        .layer(
            ServiceBuilder::new()
                // Tracing layer: span per request, INFO-level request/response logs
                .layer(
                    TraceLayer::new_for_http()
                        .make_span_with(DefaultMakeSpan::new().level(Level::INFO))
                        .on_response(DefaultOnResponse::new().level(Level::INFO)),
                )
                // CORS layer — permissive; tighten before exposing publicly
                .layer(CorsLayer::permissive())
                // Compression layer
                .layer(CompressionLayer::new()),
        )
        .with_state(state)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use tower::ServiceExt;

    /// /health must respond 200 without any authentication.
    #[tokio::test]
    async fn test_health_endpoint() {
        let state = AppState::new();
        let app = router(state);

        // `oneshot` drives a single request through the full middleware stack.
        let response = app
            .oneshot(
                Request::builder()
                    .uri("/health")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
    }
}
|
||||
148
examples/scipix/src/api/state.rs
Normal file
148
examples/scipix/src/api/state.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
use moka::future::Cache;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use super::{
|
||||
jobs::JobQueue,
|
||||
middleware::{create_rate_limiter, AppRateLimiter},
|
||||
};
|
||||
|
||||
/// Shared application state.
///
/// `Clone` is cheap: every field is either an `Arc` or an internally
/// shared handle (moka `Cache` and the rate limiter clone by reference).
#[derive(Clone)]
pub struct AppState {
    /// Job queue for async PDF processing.
    pub job_queue: Arc<JobQueue>,

    /// Result cache (keyed by request hash, values are serialized results).
    // NOTE(review): key/value semantics inferred from usage elsewhere — confirm.
    pub cache: Cache<String, String>,

    /// Rate limiter shared across handlers.
    pub rate_limiter: AppRateLimiter,

    /// Whether authentication is enabled.
    pub auth_enabled: bool,

    /// Map of app_id -> hashed API key.
    /// Keys are stored as SHA-256 hashes, never in plaintext.
    pub api_keys: Arc<HashMap<String, String>>,
}
|
||||
|
||||
impl AppState {
    /// Create a new application state instance with authentication disabled.
    pub fn new() -> Self {
        Self {
            job_queue: Arc::new(JobQueue::new()),
            cache: create_cache(),
            rate_limiter: create_rate_limiter(),
            auth_enabled: false,
            api_keys: Arc::new(HashMap::new()),
        }
    }

    /// Create state with custom configuration.
    ///
    /// * `max_jobs` — capacity of the async PDF job queue.
    /// * `cache_size` — max entries in the result cache (same 1h TTL /
    ///   10min idle policy as the default cache).
    pub fn with_config(max_jobs: usize, cache_size: u64) -> Self {
        Self {
            job_queue: Arc::new(JobQueue::with_capacity(max_jobs)),
            cache: Cache::builder()
                .max_capacity(cache_size)
                .time_to_live(Duration::from_secs(3600))
                .time_to_idle(Duration::from_secs(600))
                .build(),
            rate_limiter: create_rate_limiter(),
            auth_enabled: false,
            api_keys: Arc::new(HashMap::new()),
        }
    }

    /// Create state with authentication enabled.
    ///
    /// Plaintext keys in `api_keys` are hashed with SHA-256 before storage;
    /// the plaintext never outlives this constructor.
    pub fn with_auth(api_keys: HashMap<String, String>) -> Self {
        // Hash all provided API keys
        let hashed_keys: HashMap<String, String> = api_keys
            .into_iter()
            .map(|(app_id, key)| (app_id, hash_api_key(&key)))
            .collect();

        Self {
            job_queue: Arc::new(JobQueue::new()),
            cache: create_cache(),
            rate_limiter: create_rate_limiter(),
            auth_enabled: true,
            api_keys: Arc::new(hashed_keys),
        }
    }

    /// Add an API key (hashes the key before storing).
    ///
    /// NOTE(review): `Arc::make_mut` clones the map when other `AppState`
    /// clones share it, so the new key becomes visible only through *this*
    /// instance (and clones made afterwards) — pre-existing clones keep the
    /// old map. Confirm that is the intended semantics for runtime key
    /// registration.
    pub fn add_api_key(&mut self, app_id: String, api_key: &str) {
        let hashed = hash_api_key(api_key);
        Arc::make_mut(&mut self.api_keys).insert(app_id, hashed);
        self.auth_enabled = true;
    }

    /// Enable or disable authentication (affects only this instance and
    /// clones made from it afterwards).
    pub fn set_auth_enabled(&mut self, enabled: bool) {
        self.auth_enabled = enabled;
    }
}
|
||||
|
||||
impl Default for AppState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Hash an API key using SHA-256
|
||||
fn hash_api_key(key: &str) -> String {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(key.as_bytes());
|
||||
format!("{:x}", hasher.finalize())
|
||||
}
|
||||
|
||||
/// Create a result cache with the default eviction policy:
/// at most 10,000 entries, evicted 1 hour after write or 10 minutes
/// after the last access, whichever comes first.
fn create_cache() -> Cache<String, String> {
    Cache::builder()
        // Max 10,000 entries
        .max_capacity(10_000)
        // Time to live: 1 hour after insert
        .time_to_live(Duration::from_secs(3600))
        // Time to idle: 10 minutes since last read/write
        .time_to_idle(Duration::from_secs(600))
        .build()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction smoke test: the queue Arc exists and is owned.
    #[tokio::test]
    async fn test_state_creation() {
        let state = AppState::new();
        assert!(Arc::strong_count(&state.job_queue) >= 1);
    }

    /// `with_config` produces a usable state as well.
    #[tokio::test]
    async fn test_state_with_config() {
        let state = AppState::with_config(100, 5000);
        assert!(Arc::strong_count(&state.job_queue) >= 1);
    }

    /// Round-trip through the moka cache: insert, hit, and miss.
    #[tokio::test]
    async fn test_cache_operations() {
        let state = AppState::new();

        // Insert value
        state
            .cache
            .insert("key1".to_string(), "value1".to_string())
            .await;

        // Retrieve value
        let value = state.cache.get(&"key1".to_string()).await;
        assert_eq!(value, Some("value1".to_string()));

        // Non-existent key
        let missing = state.cache.get(&"missing".to_string()).await;
        assert_eq!(missing, None);
    }
}
|
||||
763
examples/scipix/src/bin/benchmark.rs
Normal file
763
examples/scipix/src/bin/benchmark.rs
Normal file
@@ -0,0 +1,763 @@
|
||||
//! SciPix OCR Benchmark Tool
|
||||
//!
|
||||
//! Comprehensive benchmark for OCR performance including:
|
||||
//! - Image preprocessing speed
|
||||
//! - Text detection throughput
|
||||
//! - Character recognition latency
|
||||
//! - End-to-end pipeline benchmarks
|
||||
|
||||
use image::{DynamicImage, ImageBuffer, Luma, Rgb, RgbImage};
|
||||
use imageproc::contrast::ThresholdType;
|
||||
use imageproc::drawing::draw_filled_rect_mut;
|
||||
use imageproc::rect::Rect;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
// Import SIMD optimizations
|
||||
use ruvector_scipix::optimize::simd::{
|
||||
fast_area_resize, simd_grayscale, simd_resize_bilinear, simd_threshold,
|
||||
};
|
||||
|
||||
/// Benchmark results: timing statistics for one named benchmark run.
#[derive(Debug, Clone)]
struct BenchmarkResult {
    /// Human-readable benchmark name (also used for lookups in `main`).
    name: String,
    /// Number of timed iterations (warmups excluded).
    iterations: usize,
    /// Sum of all per-iteration times.
    total_time: Duration,
    /// Mean per-iteration time.
    avg_time: Duration,
    /// Fastest single iteration.
    min_time: Duration,
    /// Slowest single iteration.
    max_time: Duration,
    /// Iterations per second over the whole run.
    throughput: f64,
}
|
||||
|
||||
impl BenchmarkResult {
    /// Print a human-readable report for this benchmark to stdout.
    fn display(&self) {
        println!("\n{}", "=".repeat(60));
        println!("Benchmark: {}", self.name);
        println!("{}", "=".repeat(60));
        println!(" Iterations: {}", self.iterations);
        println!(" Total time: {:?}", self.total_time);
        println!(" Avg time: {:?}", self.avg_time);
        println!(" Min time: {:?}", self.min_time);
        println!(" Max time: {:?}", self.max_time);
        println!(" Throughput: {:.2} ops/sec", self.throughput);
    }
}
|
||||
|
||||
/// Generate a test image with synthetic patterns (simulating text)
|
||||
fn generate_test_image(width: u32, height: u32) -> RgbImage {
|
||||
let mut img: RgbImage = ImageBuffer::from_fn(width, height, |_, _| {
|
||||
Rgb([255u8, 255u8, 255u8]) // White background
|
||||
});
|
||||
|
||||
// Draw black rectangles to simulate text blocks
|
||||
for i in 0..10 {
|
||||
let x = (i * 35 + 10) as i32;
|
||||
let y = 20;
|
||||
draw_filled_rect_mut(
|
||||
&mut img,
|
||||
Rect::at(x, y).of_size(25, 40),
|
||||
Rgb([0u8, 0u8, 0u8]),
|
||||
);
|
||||
}
|
||||
|
||||
// Draw a horizontal line (like an equation fraction)
|
||||
draw_filled_rect_mut(
|
||||
&mut img,
|
||||
Rect::at(10, 70).of_size(350, 2),
|
||||
Rgb([0u8, 0u8, 0u8]),
|
||||
);
|
||||
|
||||
img
|
||||
}
|
||||
|
||||
/// Generate a math-like test image: white canvas with crude shapes
/// resembling a fraction and a square-root radical.
fn generate_math_image(width: u32, height: u32) -> RgbImage {
    let mut img: RgbImage = ImageBuffer::from_fn(width, height, |_, _| Rgb([255u8, 255u8, 255u8]));

    // Draw elements resembling a fraction: numerator, bar, denominator.
    draw_filled_rect_mut(
        &mut img,
        Rect::at(50, 20).of_size(100, 30),
        Rgb([0u8, 0u8, 0u8]),
    );
    draw_filled_rect_mut(
        &mut img,
        Rect::at(20, 60).of_size(160, 3),
        Rgb([0u8, 0u8, 0u8]),
    );
    draw_filled_rect_mut(
        &mut img,
        Rect::at(70, 70).of_size(60, 30),
        Rgb([0u8, 0u8, 0u8]),
    );

    // Draw square root symbol approximation: vertical stem plus overbar.
    draw_filled_rect_mut(
        &mut img,
        Rect::at(200, 30).of_size(5, 40),
        Rgb([0u8, 0u8, 0u8]),
    );
    draw_filled_rect_mut(
        &mut img,
        Rect::at(200, 30).of_size(80, 3),
        Rgb([0u8, 0u8, 0u8]),
    );

    img
}
|
||||
|
||||
/// Run a benchmark function multiple times and collect statistics
|
||||
fn run_benchmark<F, E>(name: &str, iterations: usize, mut f: F) -> BenchmarkResult
|
||||
where
|
||||
F: FnMut() -> Result<(), E>,
|
||||
E: std::fmt::Debug,
|
||||
{
|
||||
let mut times = Vec::with_capacity(iterations);
|
||||
|
||||
// Warmup
|
||||
for _ in 0..3 {
|
||||
let _ = f();
|
||||
}
|
||||
|
||||
// Actual benchmark
|
||||
for _ in 0..iterations {
|
||||
let start = Instant::now();
|
||||
let _ = f();
|
||||
times.push(start.elapsed());
|
||||
}
|
||||
|
||||
let total_time: Duration = times.iter().sum();
|
||||
let avg_time = total_time / iterations as u32;
|
||||
let min_time = *times.iter().min().unwrap();
|
||||
let max_time = *times.iter().max().unwrap();
|
||||
let throughput = iterations as f64 / total_time.as_secs_f64();
|
||||
|
||||
BenchmarkResult {
|
||||
name: name.to_string(),
|
||||
iterations,
|
||||
total_time,
|
||||
avg_time,
|
||||
min_time,
|
||||
max_time,
|
||||
throughput,
|
||||
}
|
||||
}
|
||||
|
||||
/// Benchmark grayscale conversion via `image`'s `to_luma8`.
fn benchmark_grayscale(images: &[DynamicImage]) -> BenchmarkResult {
    // Round-robin over the test images so one size doesn't dominate.
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Grayscale Conversion", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let _gray = img.to_luma8();
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark high-quality image resize (Lanczos3) to 640x480.
fn benchmark_resize(images: &[DynamicImage]) -> BenchmarkResult {
    use image::imageops::FilterType;

    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Image Resize (640x480)", 100, || {
        let img = &images[idx % images.len()];
        idx += 1;
        // Lanczos3 is the slow/high-quality baseline for the resize comparison.
        let _resized = img.resize(640, 480, FilterType::Lanczos3);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark fast (nearest-neighbor) resize to 640x480, the cheap
/// counterpart to the Lanczos3 benchmark above.
fn benchmark_fast_resize(images: &[DynamicImage]) -> BenchmarkResult {
    use image::imageops::FilterType;

    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Fast Resize (Nearest)", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let _resized = img.resize(640, 480, FilterType::Nearest);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark Gaussian blur (sigma = 1.5) on the grayscale image.
/// Note the timing includes the grayscale conversion inside the closure.
fn benchmark_blur(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Gaussian Blur (σ=1.5)", 50, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _blurred = imageproc::filter::gaussian_blur_f32(&gray, 1.5);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark binarization with a fixed threshold of 128.
///
/// NOTE(review): the label says "Otsu Threshold" but this is a fixed-level
/// threshold, not Otsu's method. The label is looked up by name in `main`'s
/// analysis section, so renaming it would require changing both sites.
fn benchmark_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Otsu Threshold", 100, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _thresholded = imageproc::contrast::threshold(&gray, 128, ThresholdType::Binary);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark adaptive (local-window) thresholding with block radius 11.
fn benchmark_adaptive_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Adaptive Threshold", 30, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _thresholded = imageproc::contrast::adaptive_threshold(&gray, 11);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark memory throughput
|
||||
fn benchmark_memory_throughput() -> BenchmarkResult {
|
||||
let data: Vec<f32> = (0..1_000_000).map(|i| i as f32).collect();
|
||||
|
||||
run_benchmark::<_, std::convert::Infallible>("Memory Throughput (1M floats)", 100, || {
|
||||
let _sum: f32 = data.iter().sum();
|
||||
let _clone = data.clone();
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Benchmark tensor creation for ONNX
|
||||
fn benchmark_tensor_creation() -> BenchmarkResult {
|
||||
use ndarray::Array4;
|
||||
|
||||
run_benchmark::<_, ndarray::ShapeError>("Tensor Creation (1x3x224x224)", 100, || {
|
||||
let tensor_data: Vec<f32> = vec![0.0; 1 * 3 * 224 * 224];
|
||||
let _tensor = Array4::from_shape_vec((1, 3, 224, 224), tensor_data)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Benchmark creating a detection-sized (1x3x640x480) ndarray tensor.
fn benchmark_large_tensor() -> BenchmarkResult {
    use ndarray::Array4;

    run_benchmark::<_, ndarray::ShapeError>("Large Tensor (1x3x640x480)", 50, || {
        let tensor_data: Vec<f32> = vec![0.0; 1 * 3 * 640 * 480];
        let _tensor = Array4::from_shape_vec((1, 3, 640, 480), tensor_data)?;
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark converting an RGB image into an NCHW f32 tensor normalized
/// to [-1, 1] (pixel / 127.5 - 1), one channel plane at a time.
fn benchmark_normalization(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Image Normalization", 200, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let rgb = img.to_rgb8();
        let mut tensor = Vec::with_capacity(3 * rgb.width() as usize * rgb.height() as usize);

        // NCHW format normalization: all of channel 0, then 1, then 2.
        for c in 0..3 {
            for y in 0..rgb.height() {
                for x in 0..rgb.width() {
                    let pixel = rgb.get_pixel(x, y);
                    tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                }
            }
        }
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark image loading from disk
|
||||
fn benchmark_image_load(path: &PathBuf) -> BenchmarkResult {
|
||||
run_benchmark::<_, image::ImageError>("Image Load from Disk", 100, || {
|
||||
let _img = image::open(path)?;
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Benchmark Sobel gradient computation on the grayscale image.
fn benchmark_edge_detection(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Sobel Edge Detection", 50, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _edges = imageproc::gradients::sobel_gradients(&gray);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark 8-connected component labelling on a binarized image
/// (grayscale + fixed-128 threshold happen inside the timed closure).
fn benchmark_connected_components(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Connected Components", 50, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let binary = imageproc::contrast::threshold(&gray, 128, ThresholdType::Binary);
        // Luma([0u8]) marks the background color to be excluded from labelling.
        let _cc = imageproc::region_labelling::connected_components(
            &binary,
            imageproc::region_labelling::Connectivity::Eight,
            Luma([0u8]),
        );
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark SIMD grayscale conversion (RGBA input buffer -> luma bytes).
fn benchmark_simd_grayscale(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("SIMD Grayscale", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let rgba = img.to_rgba8();
        // Output buffer: one byte per pixel.
        let mut gray = vec![0u8; (rgba.width() * rgba.height()) as usize];
        simd_grayscale(rgba.as_raw(), &mut gray);
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark SIMD bilinear resize of the grayscale image to 640x480.
fn benchmark_simd_resize(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("SIMD Resize (Bilinear)", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _resized = simd_resize_bilinear(
            gray.as_raw(),
            gray.width() as usize,
            gray.height() as usize,
            640,
            480,
        );
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark the fast area-averaging resize to 640x480.
fn benchmark_area_resize(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Fast Area Resize", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        let _resized = fast_area_resize(
            gray.as_raw(),
            gray.width() as usize,
            gray.height() as usize,
            640,
            480,
        );
        Ok(())
    })
}
|
||||
|
||||
/// Benchmark SIMD binarization at fixed threshold 128.
fn benchmark_simd_threshold(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("SIMD Threshold", 500, || {
        let img = &images[idx % images.len()];
        idx += 1;
        let gray = img.to_luma8();
        // Output buffer matches the input length exactly.
        let mut out = vec![0u8; gray.as_raw().len()];
        simd_threshold(gray.as_raw(), 128, &mut out);
        Ok(())
    })
}
|
||||
|
||||
/// Complete preprocessing pipeline benchmark (SIMD optimized):
/// grayscale -> 224x224 bilinear resize -> threshold -> [-1,1] tensor.
/// Mirrors `benchmark_original_pipeline` for a like-for-like comparison.
fn benchmark_simd_pipeline(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("SIMD Full Pipeline", 200, || {
        let img = &images[idx % images.len()];
        idx += 1;

        // Step 1: RGBA to Grayscale
        let rgba = img.to_rgba8();
        let mut gray = vec![0u8; (rgba.width() * rgba.height()) as usize];
        simd_grayscale(rgba.as_raw(), &mut gray);

        // Step 2: Resize to the model input size
        let resized = simd_resize_bilinear(
            &gray,
            rgba.width() as usize,
            rgba.height() as usize,
            224,
            224,
        );

        // Step 3: Threshold (binarize)
        let mut binary = vec![0u8; resized.len()];
        simd_threshold(&resized, 128, &mut binary);

        // Step 4: Normalize to tensor format, pixel / 127.5 - 1 -> [-1, 1]
        let _tensor: Vec<f32> = binary.iter().map(|&x| (x as f32 / 127.5) - 1.0).collect();

        Ok(())
    })
}
|
||||
|
||||
/// Original preprocessing pipeline benchmark (for comparison with the
/// SIMD version): identical four steps implemented with image/imageproc.
fn benchmark_original_pipeline(images: &[DynamicImage]) -> BenchmarkResult {
    let mut idx = 0;
    run_benchmark::<_, std::convert::Infallible>("Original Full Pipeline", 200, || {
        let img = &images[idx % images.len()];
        idx += 1;

        // Step 1: Grayscale
        let gray = img.to_luma8();

        // Step 2: Resize (nearest-neighbor, matching the SIMD path's speed class)
        let resized =
            image::imageops::resize(&gray, 224, 224, image::imageops::FilterType::Nearest);

        // Step 3: Threshold
        let binary = imageproc::contrast::threshold(&resized, 128, ThresholdType::Binary);

        // Step 4: Normalize to [-1, 1]
        let _tensor: Vec<f32> = binary
            .as_raw()
            .iter()
            .map(|&x| (x as f32 / 127.5) - 1.0)
            .collect();

        Ok(())
    })
}
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!(" SciPix OCR Benchmark Suite");
|
||||
println!("{}", "=".repeat(60));
|
||||
println!("\nGenerating test images...");
|
||||
|
||||
// Generate test images
|
||||
let text_image = generate_test_image(400, 100);
|
||||
let math_image = generate_math_image(300, 150);
|
||||
let large_image = generate_test_image(800, 200);
|
||||
let hd_image = generate_test_image(1920, 1080);
|
||||
|
||||
// Save test images
|
||||
let test_dir = PathBuf::from("test_images");
|
||||
fs::create_dir_all(&test_dir)?;
|
||||
|
||||
text_image.save(test_dir.join("text_test.png"))?;
|
||||
math_image.save(test_dir.join("math_test.png"))?;
|
||||
large_image.save(test_dir.join("large_test.png"))?;
|
||||
hd_image.save(test_dir.join("hd_test.png"))?;
|
||||
|
||||
println!("Test images saved to test_images/\n");
|
||||
|
||||
// Convert to DynamicImage for benchmarks
|
||||
let images: Vec<DynamicImage> = vec![
|
||||
DynamicImage::ImageRgb8(text_image.clone()),
|
||||
DynamicImage::ImageRgb8(math_image.clone()),
|
||||
DynamicImage::ImageRgb8(large_image.clone()),
|
||||
];
|
||||
|
||||
let hd_images = vec![DynamicImage::ImageRgb8(hd_image.clone())];
|
||||
|
||||
// Run benchmarks
|
||||
let mut results = Vec::new();
|
||||
|
||||
println!("Running image conversion benchmarks...");
|
||||
results.push(benchmark_grayscale(&images));
|
||||
|
||||
println!("Running resize benchmarks...");
|
||||
results.push(benchmark_resize(&images));
|
||||
results.push(benchmark_fast_resize(&images));
|
||||
|
||||
println!("Running filter benchmarks...");
|
||||
results.push(benchmark_blur(&images));
|
||||
results.push(benchmark_threshold(&images));
|
||||
results.push(benchmark_adaptive_threshold(&images));
|
||||
results.push(benchmark_edge_detection(&images));
|
||||
results.push(benchmark_connected_components(&images));
|
||||
|
||||
println!("Running SIMD optimized benchmarks...");
|
||||
results.push(benchmark_simd_grayscale(&images));
|
||||
results.push(benchmark_simd_resize(&images));
|
||||
results.push(benchmark_area_resize(&images));
|
||||
results.push(benchmark_simd_threshold(&images));
|
||||
|
||||
println!("Running pipeline benchmarks...");
|
||||
results.push(benchmark_original_pipeline(&images));
|
||||
results.push(benchmark_simd_pipeline(&images));
|
||||
|
||||
println!("Running normalization benchmarks...");
|
||||
results.push(benchmark_normalization(&images));
|
||||
|
||||
println!("Running memory benchmarks...");
|
||||
results.push(benchmark_memory_throughput());
|
||||
results.push(benchmark_tensor_creation());
|
||||
results.push(benchmark_large_tensor());
|
||||
|
||||
println!("Running I/O benchmarks...");
|
||||
results.push(benchmark_image_load(&test_dir.join("text_test.png")));
|
||||
|
||||
println!("\nRunning HD image benchmarks...");
|
||||
results.push(run_benchmark::<_, std::convert::Infallible>(
|
||||
"HD Grayscale (1920x1080)",
|
||||
100,
|
||||
|| {
|
||||
let _gray = hd_images[0].to_luma8();
|
||||
Ok(())
|
||||
},
|
||||
));
|
||||
results.push(run_benchmark::<_, std::convert::Infallible>(
|
||||
"HD Resize to 640x480",
|
||||
50,
|
||||
|| {
|
||||
let _resized = hd_images[0].resize(640, 480, image::imageops::FilterType::Lanczos3);
|
||||
Ok(())
|
||||
},
|
||||
));
|
||||
|
||||
// Display results
|
||||
println!("\n\n{}", "#".repeat(60));
|
||||
println!(" BENCHMARK RESULTS");
|
||||
println!("{}", "#".repeat(60));
|
||||
|
||||
for result in &results {
|
||||
result.display();
|
||||
}
|
||||
|
||||
// Summary table
|
||||
println!("\n\n{}", "=".repeat(75));
|
||||
println!("{:45} {:>15} {:>15}", "Benchmark", "Avg Time", "Throughput");
|
||||
println!("{}", "-".repeat(75));
|
||||
for result in &results {
|
||||
println!(
|
||||
"{:45} {:>15.2?} {:>12.2} ops/s",
|
||||
result.name, result.avg_time, result.throughput
|
||||
);
|
||||
}
|
||||
println!("{}", "=".repeat(75));
|
||||
|
||||
// Performance analysis
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!(" PERFORMANCE ANALYSIS");
|
||||
println!("{}", "=".repeat(60));
|
||||
|
||||
// Calculate total preprocessing time for a typical pipeline
|
||||
let grayscale_time = results
|
||||
.iter()
|
||||
.find(|r| r.name == "Grayscale Conversion")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let resize_time = results
|
||||
.iter()
|
||||
.find(|r| r.name == "Fast Resize (Nearest)")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let threshold_time = results
|
||||
.iter()
|
||||
.find(|r| r.name == "Otsu Threshold")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let normalize_time = results
|
||||
.iter()
|
||||
.find(|r| r.name == "Image Normalization")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
|
||||
let total_preprocess = grayscale_time + resize_time + threshold_time + normalize_time;
|
||||
|
||||
// SIMD optimized times
|
||||
let simd_grayscale = results
|
||||
.iter()
|
||||
.find(|r| r.name == "SIMD Grayscale")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let simd_resize = results
|
||||
.iter()
|
||||
.find(|r| r.name == "SIMD Resize (Bilinear)")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let simd_threshold = results
|
||||
.iter()
|
||||
.find(|r| r.name == "SIMD Threshold")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
|
||||
let original_pipeline = results
|
||||
.iter()
|
||||
.find(|r| r.name == "Original Full Pipeline")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
let simd_pipeline = results
|
||||
.iter()
|
||||
.find(|r| r.name == "SIMD Full Pipeline")
|
||||
.map(|r| r.avg_time)
|
||||
.unwrap_or_default();
|
||||
|
||||
println!("\n┌──────────────────────────────────────────────────────────────────┐");
|
||||
println!("│ SIMD Optimization Comparison │");
|
||||
println!("├────────────────────┬──────────────┬──────────────┬───────────────┤");
|
||||
println!("│ Operation │ Original │ SIMD │ Speedup │");
|
||||
println!("├────────────────────┼──────────────┼──────────────┼───────────────┤");
|
||||
println!(
|
||||
"│ Grayscale │ {:>10.2?} │ {:>10.2?} │ {:>6.2}x │",
|
||||
grayscale_time,
|
||||
simd_grayscale,
|
||||
if simd_grayscale.as_nanos() > 0 {
|
||||
grayscale_time.as_secs_f64() / simd_grayscale.as_secs_f64()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
);
|
||||
println!(
|
||||
"│ Resize │ {:>10.2?} │ {:>10.2?} │ {:>6.2}x │",
|
||||
resize_time,
|
||||
simd_resize,
|
||||
if simd_resize.as_nanos() > 0 {
|
||||
resize_time.as_secs_f64() / simd_resize.as_secs_f64()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
);
|
||||
println!(
|
||||
"│ Threshold │ {:>10.2?} │ {:>10.2?} │ {:>6.2}x │",
|
||||
threshold_time,
|
||||
simd_threshold,
|
||||
if simd_threshold.as_nanos() > 0 {
|
||||
threshold_time.as_secs_f64() / simd_threshold.as_secs_f64()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
);
|
||||
println!("├────────────────────┼──────────────┼──────────────┼───────────────┤");
|
||||
println!(
|
||||
"│ Full Pipeline │ {:>10.2?} │ {:>10.2?} │ {:>6.2}x │",
|
||||
original_pipeline,
|
||||
simd_pipeline,
|
||||
if simd_pipeline.as_nanos() > 0 {
|
||||
original_pipeline.as_secs_f64() / simd_pipeline.as_secs_f64()
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
);
|
||||
println!("└────────────────────┴──────────────┴──────────────┴───────────────┘");
|
||||
|
||||
println!("\n┌──────────────────────────────────────────────────┐");
|
||||
println!("│ Typical Preprocessing Pipeline Breakdown │");
|
||||
println!("├──────────────────────────────────────────────────┤");
|
||||
println!(
|
||||
"│ Grayscale: {:>10.2?} ({:.1}%) │",
|
||||
grayscale_time,
|
||||
100.0 * grayscale_time.as_secs_f64() / total_preprocess.as_secs_f64()
|
||||
);
|
||||
println!(
|
||||
"│ Resize: {:>10.2?} ({:.1}%) │",
|
||||
resize_time,
|
||||
100.0 * resize_time.as_secs_f64() / total_preprocess.as_secs_f64()
|
||||
);
|
||||
println!(
|
||||
"│ Threshold: {:>10.2?} ({:.1}%) │",
|
||||
threshold_time,
|
||||
100.0 * threshold_time.as_secs_f64() / total_preprocess.as_secs_f64()
|
||||
);
|
||||
println!(
|
||||
"│ Normalization: {:>10.2?} ({:.1}%) │",
|
||||
normalize_time,
|
||||
100.0 * normalize_time.as_secs_f64() / total_preprocess.as_secs_f64()
|
||||
);
|
||||
println!("├──────────────────────────────────────────────────┤");
|
||||
println!(
|
||||
"│ TOTAL: {:>10.2?} │",
|
||||
total_preprocess
|
||||
);
|
||||
println!("└──────────────────────────────────────────────────┘");
|
||||
|
||||
println!("\nTarget latency for real-time (30 fps): 33.3ms");
|
||||
|
||||
if total_preprocess.as_millis() < 33 {
|
||||
println!(
|
||||
"✓ Preprocessing meets real-time requirements ({:.1}ms < 33.3ms)",
|
||||
total_preprocess.as_secs_f64() * 1000.0
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"⚠ Preprocessing exceeds real-time target ({:.1}ms > 33.3ms)",
|
||||
total_preprocess.as_secs_f64() * 1000.0
|
||||
);
|
||||
}
|
||||
|
||||
// Memory efficiency
|
||||
let tensor_throughput = results
|
||||
.iter()
|
||||
.find(|r| r.name.contains("Tensor Creation"))
|
||||
.map(|r| r.throughput)
|
||||
.unwrap_or(0.0);
|
||||
|
||||
println!(
|
||||
"\nTensor creation throughput: {:.0} tensors/sec",
|
||||
tensor_throughput
|
||||
);
|
||||
println!("Target for batch inference: >100 tensors/sec");
|
||||
|
||||
if tensor_throughput > 100.0 {
|
||||
println!("✓ Tensor creation meets batch requirements");
|
||||
} else {
|
||||
println!("⚠ Consider tensor pooling optimization");
|
||||
}
|
||||
|
||||
// Estimated end-to-end throughput
|
||||
let estimated_ocr_time = total_preprocess.as_secs_f64() * 1000.0 + 50.0; // preprocessing + estimated inference
|
||||
let estimated_throughput = 1000.0 / estimated_ocr_time;
|
||||
|
||||
println!("\n┌──────────────────────────────────────────────────┐");
|
||||
println!("│ Estimated End-to-End Performance │");
|
||||
println!("├──────────────────────────────────────────────────┤");
|
||||
println!(
|
||||
"│ Preprocessing: {:>8.2}ms │",
|
||||
total_preprocess.as_secs_f64() * 1000.0
|
||||
);
|
||||
println!("│ Est. Inference: {:>8.2}ms (target) │", 50.0);
|
||||
println!(
|
||||
"│ Total latency: {:>8.2}ms │",
|
||||
estimated_ocr_time
|
||||
);
|
||||
println!(
|
||||
"│ Throughput: {:>8.1} images/sec │",
|
||||
estimated_throughput
|
||||
);
|
||||
println!("└──────────────────────────────────────────────────┘");
|
||||
|
||||
// State of the art comparison
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!(" STATE OF THE ART COMPARISON");
|
||||
println!("{}", "=".repeat(60));
|
||||
println!("\n┌────────────────────────────────────────────────────────┐");
|
||||
println!("│ System │ Latency │ Throughput │ Status │");
|
||||
println!("├────────────────────────────────────────────────────────┤");
|
||||
println!("│ Tesseract │ ~200ms │ ~5 img/s │ Slow │");
|
||||
println!("│ PaddleOCR │ ~50ms │ ~20 img/s │ Fast │");
|
||||
println!("│ EasyOCR │ ~100ms │ ~10 img/s │ Medium │");
|
||||
println!(
|
||||
"│ SciPix (est.) │ {:>6.1}ms │ {:>6.1} img/s │ {}│",
|
||||
estimated_ocr_time,
|
||||
estimated_throughput,
|
||||
if estimated_throughput > 15.0 {
|
||||
"Fast "
|
||||
} else if estimated_throughput > 8.0 {
|
||||
"Medium "
|
||||
} else {
|
||||
"Slow "
|
||||
}
|
||||
);
|
||||
println!("└────────────────────────────────────────────────────────┘");
|
||||
|
||||
println!("\n{}", "=".repeat(60));
|
||||
println!("Benchmark complete!");
|
||||
println!("{}", "=".repeat(60));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
66
examples/scipix/src/bin/cli.rs
Normal file
66
examples/scipix/src/bin/cli.rs
Normal file
@@ -0,0 +1,66 @@
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use ruvector_scipix::cli::{Cli, Commands};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging based on verbosity
|
||||
let log_level = if cli.quiet {
|
||||
tracing::Level::ERROR
|
||||
} else if cli.verbose {
|
||||
tracing::Level::DEBUG
|
||||
} else {
|
||||
tracing::Level::INFO
|
||||
};
|
||||
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| format!("{}={}", env!("CARGO_PKG_NAME"), log_level).into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
// Execute the command
|
||||
match &cli.command {
|
||||
Commands::Ocr(args) => {
|
||||
ruvector_scipix::cli::commands::ocr::execute(args.clone(), &cli).await?;
|
||||
}
|
||||
Commands::Batch(args) => {
|
||||
ruvector_scipix::cli::commands::batch::execute(args.clone(), &cli).await?;
|
||||
}
|
||||
Commands::Serve(args) => {
|
||||
ruvector_scipix::cli::commands::serve::execute(args.clone(), &cli).await?;
|
||||
}
|
||||
Commands::Mcp(args) => {
|
||||
ruvector_scipix::cli::commands::mcp::run(args.clone()).await?;
|
||||
}
|
||||
Commands::Config(args) => {
|
||||
ruvector_scipix::cli::commands::config::execute(args.clone(), &cli).await?;
|
||||
}
|
||||
Commands::Doctor(args) => {
|
||||
ruvector_scipix::cli::commands::doctor::execute(args.clone()).await?;
|
||||
}
|
||||
Commands::Version => {
|
||||
println!("scipix-cli v{}", env!("CARGO_PKG_VERSION"));
|
||||
println!("A Rust-based CLI for Scipix OCR processing");
|
||||
}
|
||||
Commands::Completions { shell } => {
|
||||
use clap::CommandFactory;
|
||||
use clap_complete::{generate, Shell};
|
||||
|
||||
let shell = shell
|
||||
.clone()
|
||||
.unwrap_or_else(|| Shell::from_env().unwrap_or(Shell::Bash));
|
||||
|
||||
let mut cmd = Cli::command();
|
||||
let bin_name = cmd.get_name().to_string();
|
||||
generate(shell, &mut cmd, bin_name, &mut std::io::stdout());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
37
examples/scipix/src/bin/server.rs
Normal file
37
examples/scipix/src/bin/server.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use anyhow::Result;
|
||||
use std::net::SocketAddr;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
|
||||
|
||||
use ruvector_scipix::api::{state::AppState, ApiServer};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
// Initialize tracing
|
||||
tracing_subscriber::registry()
|
||||
.with(
|
||||
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||
.unwrap_or_else(|_| "scipix_server=debug,tower_http=debug,axum=trace".into()),
|
||||
)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
|
||||
info!("Initializing Scipix API Server");
|
||||
|
||||
// Load configuration from environment
|
||||
dotenvy::dotenv().ok();
|
||||
|
||||
// Create application state
|
||||
let state = AppState::new();
|
||||
|
||||
// Parse server address
|
||||
let addr = std::env::var("SERVER_ADDR")
|
||||
.unwrap_or_else(|_| "127.0.0.1:3000".to_string())
|
||||
.parse::<SocketAddr>()?;
|
||||
|
||||
// Create and start server
|
||||
let server = ApiServer::new(state, addr);
|
||||
server.start().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
488
examples/scipix/src/cache/mod.rs
vendored
Normal file
488
examples/scipix/src/cache/mod.rs
vendored
Normal file
@@ -0,0 +1,488 @@
|
||||
//! Vector-based intelligent caching for Scipix OCR results
|
||||
//!
|
||||
//! Uses ruvector-core for efficient similarity search and LRU eviction.
|
||||
|
||||
use crate::config::CacheConfig;
|
||||
use crate::error::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
/// Cached OCR result with metadata
///
/// Serde-derived so results can be moved across process boundaries or
/// written to disk (see the `persistent` flag on the cache configuration).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedResult {
    /// LaTeX output
    pub latex: String,

    /// Alternative formats (MathML, AsciiMath)
    // keyed by format name, value is the rendered string
    pub alternatives: HashMap<String, String>,

    /// Confidence score
    // presumably in [0.0, 1.0] — TODO confirm against the OCR engine
    pub confidence: f32,

    /// Cache timestamp
    // seconds since the Unix epoch (same clock as CacheManager::current_timestamp)
    pub timestamp: u64,

    /// Access count
    pub access_count: usize,

    /// Image hash
    // hex digest produced by CacheManager::hash_image
    pub image_hash: String,
}
|
||||
|
||||
/// Cache entry with vector embedding
///
/// Internal storage record pairing the similarity-search embedding with the
/// cached payload and the timestamp used for TTL expiry.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Vector embedding of image
    embedding: Vec<f32>,

    /// Cached result
    result: CachedResult,

    /// Last access time
    // seconds since the Unix epoch; set when the entry is stored and NOT
    // refreshed on lookup hits (lookups only take a read lock on `entries`)
    last_access: u64,
}
|
||||
|
||||
/// Vector-based cache manager
///
/// Thread-safe: all interior state lives behind `Arc<RwLock<…>>`, so a single
/// `CacheManager` can be shared across threads/tasks by cloning the `Arc`s it
/// contains (or wrapping the manager itself in an `Arc`).
pub struct CacheManager {
    /// Configuration
    config: CacheConfig,

    /// Cache entries (thread-safe)
    // keyed by the image hash produced by `hash_image`
    entries: Arc<RwLock<HashMap<String, CacheEntry>>>,

    /// LRU tracking
    // oldest key at index 0, most recently used at the back
    lru_order: Arc<RwLock<Vec<String>>>,

    /// Cache statistics
    stats: Arc<RwLock<CacheStats>>,
}
|
||||
|
||||
/// Cache statistics
///
/// Counters accumulated over the lifetime of a `CacheManager`; snapshot via
/// `CacheManager::stats()`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CacheStats {
    /// Total cache hits
    // incremented for both exact-hash and similarity hits
    pub hits: u64,

    /// Total cache misses
    pub misses: u64,

    /// Total entries
    // current entry count, not cumulative
    pub entries: usize,

    /// Total evictions
    pub evictions: u64,

    /// Average similarity score for hits
    // rolling average; only similarity-based hits contribute a score
    pub avg_similarity: f32,
}
|
||||
|
||||
impl CacheStats {
|
||||
/// Calculate hit rate
|
||||
pub fn hit_rate(&self) -> f32 {
|
||||
if self.hits + self.misses == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.hits as f32 / (self.hits + self.misses) as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl CacheManager {
|
||||
/// Create new cache manager
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Cache configuration
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use ruvector_scipix::{CacheConfig, cache::CacheManager};
|
||||
///
|
||||
/// let config = CacheConfig {
|
||||
/// enabled: true,
|
||||
/// capacity: 1000,
|
||||
/// similarity_threshold: 0.95,
|
||||
/// ttl: 3600,
|
||||
/// vector_dimension: 512,
|
||||
/// persistent: false,
|
||||
/// cache_dir: ".cache".to_string(),
|
||||
/// };
|
||||
///
|
||||
/// let cache = CacheManager::new(config);
|
||||
/// ```
|
||||
pub fn new(config: CacheConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
entries: Arc::new(RwLock::new(HashMap::new())),
|
||||
lru_order: Arc::new(RwLock::new(Vec::new())),
|
||||
stats: Arc::new(RwLock::new(CacheStats::default())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate embedding for image
|
||||
///
|
||||
/// This is a placeholder - in production, use actual vision model
|
||||
fn generate_embedding(&self, image_data: &[u8]) -> Result<Vec<f32>> {
|
||||
// Placeholder: Simple hash-based embedding
|
||||
// In production: Use Vision Transformer or similar
|
||||
let hash = self.hash_image(image_data);
|
||||
let mut embedding = vec![0.0; self.config.vector_dimension];
|
||||
|
||||
for (i, byte) in hash.as_bytes().iter().enumerate() {
|
||||
if i < embedding.len() {
|
||||
embedding[i] = *byte as f32 / 255.0;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(embedding)
|
||||
}
|
||||
|
||||
/// Hash image data
|
||||
fn hash_image(&self, image_data: &[u8]) -> String {
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::hash::{Hash, Hasher};
|
||||
|
||||
let mut hasher = DefaultHasher::new();
|
||||
image_data.hash(&mut hasher);
|
||||
format!("{:x}", hasher.finish())
|
||||
}
|
||||
|
||||
/// Calculate cosine similarity between vectors
|
||||
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
|
||||
if a.len() != b.len() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
|
||||
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
|
||||
|
||||
if norm_a == 0.0 || norm_b == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
dot_product / (norm_a * norm_b)
|
||||
}
|
||||
|
||||
/// Look up cached result by image similarity
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `image_data` - Raw image bytes
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Cached result if similarity exceeds threshold, None otherwise
|
||||
pub fn lookup(&self, image_data: &[u8]) -> Result<Option<CachedResult>> {
|
||||
if !self.config.enabled {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let embedding = self.generate_embedding(image_data)?;
|
||||
let hash = self.hash_image(image_data);
|
||||
|
||||
let entries = self.entries.read().unwrap();
|
||||
|
||||
// First try exact hash match
|
||||
if let Some(entry) = entries.get(&hash) {
|
||||
if !self.is_expired(&entry) {
|
||||
self.record_hit();
|
||||
self.update_lru(&hash);
|
||||
return Ok(Some(entry.result.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// Then try similarity search
|
||||
let mut best_match: Option<(String, f32, CachedResult)> = None;
|
||||
|
||||
for (key, entry) in entries.iter() {
|
||||
if self.is_expired(entry) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let similarity = self.cosine_similarity(&embedding, &entry.embedding);
|
||||
|
||||
if similarity >= self.config.similarity_threshold {
|
||||
if best_match.is_none() || similarity > best_match.as_ref().unwrap().1 {
|
||||
best_match = Some((key.clone(), similarity, entry.result.clone()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some((key, similarity, result)) = best_match {
|
||||
self.record_hit_with_similarity(similarity);
|
||||
self.update_lru(&key);
|
||||
Ok(Some(result))
|
||||
} else {
|
||||
self.record_miss();
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Store result in cache
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `image_data` - Raw image bytes
|
||||
/// * `result` - OCR result to cache
|
||||
pub fn store(&self, image_data: &[u8], result: CachedResult) -> Result<()> {
|
||||
if !self.config.enabled {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let embedding = self.generate_embedding(image_data)?;
|
||||
let hash = self.hash_image(image_data);
|
||||
|
||||
let entry = CacheEntry {
|
||||
embedding,
|
||||
result,
|
||||
last_access: self.current_timestamp(),
|
||||
};
|
||||
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
|
||||
// Check if we need to evict
|
||||
if entries.len() >= self.config.capacity && !entries.contains_key(&hash) {
|
||||
self.evict_lru(&mut entries);
|
||||
}
|
||||
|
||||
entries.insert(hash.clone(), entry);
|
||||
self.update_lru(&hash);
|
||||
self.update_stats_entries(entries.len());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if entry is expired
|
||||
fn is_expired(&self, entry: &CacheEntry) -> bool {
|
||||
let current = self.current_timestamp();
|
||||
current - entry.last_access > self.config.ttl
|
||||
}
|
||||
|
||||
/// Get current timestamp
|
||||
fn current_timestamp(&self) -> u64 {
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
}
|
||||
|
||||
/// Evict least recently used entry
|
||||
fn evict_lru(&self, entries: &mut HashMap<String, CacheEntry>) {
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
|
||||
if let Some(key) = lru.first() {
|
||||
entries.remove(key);
|
||||
lru.remove(0);
|
||||
self.record_eviction();
|
||||
}
|
||||
}
|
||||
|
||||
/// Update LRU order
|
||||
fn update_lru(&self, key: &str) {
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
lru.retain(|k| k != key);
|
||||
lru.push(key.to_string());
|
||||
}
|
||||
|
||||
/// Record cache hit
|
||||
fn record_hit(&self) {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.hits += 1;
|
||||
}
|
||||
|
||||
/// Record cache hit with similarity
|
||||
fn record_hit_with_similarity(&self, similarity: f32) {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.hits += 1;
|
||||
|
||||
// Update rolling average
|
||||
let total = stats.hits as f32;
|
||||
stats.avg_similarity = (stats.avg_similarity * (total - 1.0) + similarity) / total;
|
||||
}
|
||||
|
||||
/// Record cache miss
|
||||
fn record_miss(&self) {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.misses += 1;
|
||||
}
|
||||
|
||||
/// Record eviction
|
||||
fn record_eviction(&self) {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.evictions += 1;
|
||||
}
|
||||
|
||||
/// Update entry count
|
||||
fn update_stats_entries(&self, count: usize) {
|
||||
let mut stats = self.stats.write().unwrap();
|
||||
stats.entries = count;
|
||||
}
|
||||
|
||||
/// Get cache statistics
|
||||
pub fn stats(&self) -> CacheStats {
|
||||
self.stats.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Clear all cache entries
|
||||
pub fn clear(&self) {
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
|
||||
entries.clear();
|
||||
lru.clear();
|
||||
|
||||
self.update_stats_entries(0);
|
||||
}
|
||||
|
||||
/// Remove expired entries
|
||||
pub fn cleanup(&self) {
|
||||
let mut entries = self.entries.write().unwrap();
|
||||
let mut lru = self.lru_order.write().unwrap();
|
||||
|
||||
let expired: Vec<String> = entries
|
||||
.iter()
|
||||
.filter(|(_, entry)| self.is_expired(entry))
|
||||
.map(|(key, _)| key.clone())
|
||||
.collect();
|
||||
|
||||
for key in &expired {
|
||||
entries.remove(key);
|
||||
lru.retain(|k| k != key);
|
||||
}
|
||||
|
||||
self.update_stats_entries(entries.len());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: enabled cache with small capacity and a short
    // embedding dimension so tests stay fast.
    fn test_config() -> CacheConfig {
        CacheConfig {
            enabled: true,
            capacity: 100,
            similarity_threshold: 0.95,
            ttl: 3600,
            vector_dimension: 128,
            persistent: false,
            cache_dir: ".cache/test".to_string(),
        }
    }

    // Minimal cached-result fixture.
    fn test_result() -> CachedResult {
        CachedResult {
            latex: r"\frac{x^2}{2}".to_string(),
            alternatives: HashMap::new(),
            confidence: 0.95,
            timestamp: 0,
            access_count: 0,
            image_hash: "test".to_string(),
        }
    }

    // A freshly constructed cache starts with zeroed statistics.
    #[test]
    fn test_cache_creation() {
        let config = test_config();
        let cache = CacheManager::new(config);
        assert_eq!(cache.stats().hits, 0);
        assert_eq!(cache.stats().misses, 0);
    }

    // Round trip: a stored result is returned verbatim by lookup.
    #[test]
    fn test_store_and_lookup() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image data";
        let result = test_result();

        cache.store(image_data, result.clone()).unwrap();

        let lookup_result = cache.lookup(image_data).unwrap();
        assert!(lookup_result.is_some());
        assert_eq!(lookup_result.unwrap().latex, result.latex);
    }

    // Unknown input yields None and bumps the miss counter.
    #[test]
    fn test_cache_miss() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"nonexistent image";
        let lookup_result = cache.lookup(image_data).unwrap();

        assert!(lookup_result.is_none());
        assert_eq!(cache.stats().misses, 1);
    }

    // Two hits and one miss should give a hit rate of ~2/3.
    #[test]
    fn test_cache_hit_rate() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        // Store and lookup once
        cache.store(image_data, result).unwrap();
        cache.lookup(image_data).unwrap();

        // Lookup again (hit)
        cache.lookup(image_data).unwrap();

        // Lookup different image (miss)
        cache.lookup(b"different image").unwrap();

        let stats = cache.stats();
        assert_eq!(stats.hits, 2);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate() - 0.666).abs() < 0.01);
    }

    // Identical vectors have similarity 1; orthogonal vectors similarity 0.
    #[test]
    fn test_cosine_similarity() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let vec_a = vec![1.0, 0.0, 0.0];
        let vec_b = vec![1.0, 0.0, 0.0];
        let vec_c = vec![0.0, 1.0, 0.0];

        assert!((cache.cosine_similarity(&vec_a, &vec_b) - 1.0).abs() < 0.01);
        assert!((cache.cosine_similarity(&vec_a, &vec_c) - 0.0).abs() < 0.01);
    }

    // clear() drops all entries and resets the entry count.
    #[test]
    fn test_cache_clear() {
        let config = test_config();
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        cache.store(image_data, result).unwrap();
        assert_eq!(cache.stats().entries, 1);

        cache.clear();
        assert_eq!(cache.stats().entries, 0);
    }

    // With the cache disabled, store is a no-op and lookup always misses.
    #[test]
    fn test_disabled_cache() {
        let mut config = test_config();
        config.enabled = false;
        let cache = CacheManager::new(config);

        let image_data = b"test image";
        let result = test_result();

        cache.store(image_data, result).unwrap();
        let lookup_result = cache.lookup(image_data).unwrap();

        assert!(lookup_result.is_none());
    }
}
|
||||
399
examples/scipix/src/cli/commands/batch.rs
Normal file
399
examples/scipix/src/cli/commands/batch.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Args;
|
||||
use glob::glob;
|
||||
use indicatif::{MultiProgress, ProgressBar, ProgressStyle};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::{OcrConfig, OcrResult};
|
||||
use crate::cli::{output, Cli, OutputFormat};
|
||||
|
||||
/// Process multiple files in batch mode
///
/// CLI argument bundle for the `batch` subcommand; parsed by clap.
#[derive(Args, Debug, Clone)]
pub struct BatchArgs {
    /// Input pattern (glob) or directory
    // a directory is expanded into a glob by `collect_files`
    #[arg(value_name = "PATTERN", help = "Input pattern (glob) or directory")]
    pub pattern: String,

    /// Output directory for results
    #[arg(
        short,
        long,
        value_name = "DIR",
        help = "Output directory for results (default: stdout as JSON array)"
    )]
    pub output: Option<PathBuf>,

    /// Number of parallel workers
    // used as the permit count of the semaphore in `process_files_parallel`
    #[arg(
        short,
        long,
        default_value = "4",
        help = "Number of parallel processing workers"
    )]
    pub parallel: usize,

    /// Minimum confidence threshold (0.0 to 1.0)
    // compared against each result's `confidence` in `execute`
    #[arg(
        short = 't',
        long,
        default_value = "0.7",
        help = "Minimum confidence threshold for results"
    )]
    pub threshold: f64,

    /// Continue on errors
    // when false, `execute` fails if any result falls below the threshold
    #[arg(
        short = 'c',
        long,
        help = "Continue processing even if some files fail"
    )]
    pub continue_on_error: bool,

    /// Maximum retry attempts per file
    #[arg(
        short = 'r',
        long,
        default_value = "2",
        help = "Maximum retry attempts per file on failure"
    )]
    pub max_retries: usize,

    /// Save individual results as separate files
    #[arg(long, help = "Save each result as a separate file (requires --output)")]
    pub separate_files: bool,

    /// Recursive directory search
    // only meaningful when `pattern` names a directory
    #[arg(short = 'R', long, help = "Recursively search directories")]
    pub recursive: bool,
}
|
||||
|
||||
pub async fn execute(args: BatchArgs, cli: &Cli) -> Result<()> {
|
||||
info!("Starting batch processing with pattern: {}", args.pattern);
|
||||
|
||||
// Load configuration
|
||||
let config = Arc::new(load_config(cli.config.as_ref())?);
|
||||
|
||||
// Expand pattern to file list
|
||||
let files = collect_files(&args)?;
|
||||
|
||||
if files.is_empty() {
|
||||
anyhow::bail!("No files found matching pattern: {}", args.pattern);
|
||||
}
|
||||
|
||||
info!("Found {} files to process", files.len());
|
||||
|
||||
// Create output directory if needed
|
||||
if let Some(output_dir) = &args.output {
|
||||
std::fs::create_dir_all(output_dir).context("Failed to create output directory")?;
|
||||
}
|
||||
|
||||
// Process files in parallel with progress bars
|
||||
let results = process_files_parallel(files, &args, &config, cli.quiet).await?;
|
||||
|
||||
// Filter by confidence threshold
|
||||
let (passed, failed): (Vec<_>, Vec<_>) = results
|
||||
.into_iter()
|
||||
.partition(|r| r.confidence >= args.threshold);
|
||||
|
||||
info!(
|
||||
"Processing complete: {} passed, {} failed threshold",
|
||||
passed.len(),
|
||||
failed.len()
|
||||
);
|
||||
|
||||
// Save or display results
|
||||
if let Some(output_dir) = &args.output {
|
||||
save_results(&passed, output_dir, &cli.format, args.separate_files)?;
|
||||
|
||||
if !cli.quiet {
|
||||
println!("Results saved to: {}", output_dir.display());
|
||||
}
|
||||
} else {
|
||||
// Output as JSON array to stdout
|
||||
let json = serde_json::to_string_pretty(&passed).context("Failed to serialize results")?;
|
||||
println!("{}", json);
|
||||
}
|
||||
|
||||
// Display summary
|
||||
if !cli.quiet {
|
||||
output::print_batch_summary(&passed, &failed, args.threshold);
|
||||
}
|
||||
|
||||
// Return error if any files failed and continue_on_error is false
|
||||
if !failed.is_empty() && !args.continue_on_error {
|
||||
anyhow::bail!("{} files failed confidence threshold", failed.len());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn collect_files(args: &BatchArgs) -> Result<Vec<PathBuf>> {
|
||||
let mut files = Vec::new();
|
||||
let path = PathBuf::from(&args.pattern);
|
||||
|
||||
if path.is_dir() {
|
||||
// Directory mode
|
||||
let pattern = if args.recursive {
|
||||
format!("{}/**/*", args.pattern)
|
||||
} else {
|
||||
format!("{}/*", args.pattern)
|
||||
};
|
||||
|
||||
for entry in glob(&pattern).context("Failed to read glob pattern")? {
|
||||
match entry {
|
||||
Ok(path) => {
|
||||
if path.is_file() {
|
||||
files.push(path);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to read entry: {}", e),
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Glob pattern mode
|
||||
for entry in glob(&args.pattern).context("Failed to read glob pattern")? {
|
||||
match entry {
|
||||
Ok(path) => {
|
||||
if path.is_file() {
|
||||
files.push(path);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to read entry: {}", e),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
/// Process `files` concurrently, limited to `args.parallel` tasks at a time.
///
/// Each file is handled in its own tokio task gated by a semaphore permit.
/// Unless `quiet` is set, an overall progress bar plus one spinner per
/// in-flight file is rendered via `MultiProgress`. Files that fail (after
/// retries in `process_with_retry`) are logged and omitted from the returned
/// results instead of aborting the whole batch.
async fn process_files_parallel(
    files: Vec<PathBuf>,
    args: &BatchArgs,
    config: &Arc<OcrConfig>,
    quiet: bool,
) -> Result<Vec<OcrResult>> {
    // Semaphore caps the number of files processed simultaneously.
    let semaphore = Arc::new(Semaphore::new(args.parallel));
    let multi_progress = Arc::new(MultiProgress::new());

    let overall_progress = if !quiet {
        let pb = multi_progress.add(ProgressBar::new(files.len() as u64));
        pb.set_style(
            ProgressStyle::default_bar()
                .template(
                    "{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})",
                )
                .unwrap()
                .progress_chars("#>-"),
        );
        Some(pb)
    } else {
        None
    };

    let mut handles = Vec::new();

    for (_idx, file) in files.into_iter().enumerate() {
        // Clone shared handles so the spawned task can own them.
        let semaphore = semaphore.clone();
        let config = config.clone();
        let multi_progress = multi_progress.clone();
        let overall_progress = overall_progress.clone();
        let max_retries = args.max_retries;

        let handle = tokio::spawn(async move {
            // Wait for a free slot; the permit is held for the task's lifetime.
            let _permit = semaphore.acquire().await.unwrap();

            let file_progress = if !quiet {
                // When !quiet, overall_progress was set above, so this
                // unwrap cannot fail.
                let pb = multi_progress.insert_before(
                    &overall_progress.as_ref().unwrap(),
                    ProgressBar::new_spinner(),
                );
                pb.set_style(
                    ProgressStyle::default_spinner()
                        .template("{spinner:.green} {msg}")
                        .unwrap(),
                );
                pb.set_message(format!("[{}] Processing...", file.display()));
                Some(pb)
            } else {
                None
            };

            let result = process_with_retry(&file, &config, max_retries).await;

            // Replace the spinner with a final per-file status line.
            if let Some(pb) = &file_progress {
                match &result {
                    Ok(r) => pb.finish_with_message(format!(
                        "[{}] ✓ Confidence: {:.2}%",
                        file.display(),
                        r.confidence * 100.0
                    )),
                    Err(e) => {
                        pb.finish_with_message(format!("[{}] ✗ Error: {}", file.display(), e))
                    }
                }
            }

            if let Some(pb) = &overall_progress {
                pb.inc(1);
            }

            result
        });

        handles.push(handle);
    }

    // Wait for all tasks to complete
    let mut results = Vec::new();
    for handle in handles {
        match handle.await {
            Ok(Ok(result)) => results.push(result),
            // Per-file failures are logged but do not fail the batch.
            Ok(Err(e)) => error!("Processing failed: {}", e),
            Err(e) => error!("Task panicked: {}", e),
        }
    }

    if let Some(pb) = overall_progress {
        pb.finish_with_message("Batch processing complete");
    }

    Ok(results)
}
|
||||
|
||||
async fn process_with_retry(
|
||||
file: &PathBuf,
|
||||
config: &OcrConfig,
|
||||
max_retries: usize,
|
||||
) -> Result<OcrResult> {
|
||||
let mut attempts = 0;
|
||||
let mut last_error = None;
|
||||
|
||||
while attempts <= max_retries {
|
||||
match process_single_file(file, config).await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(e) => {
|
||||
attempts += 1;
|
||||
last_error = Some(e);
|
||||
|
||||
if attempts <= max_retries {
|
||||
debug!("Retry {}/{} for {}", attempts, max_retries, file.display());
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100 * attempts as u64))
|
||||
.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(last_error.unwrap())
|
||||
}
|
||||
|
||||
async fn process_single_file(file: &PathBuf, _config: &OcrConfig) -> Result<OcrResult> {
|
||||
// TODO: Implement actual OCR processing
|
||||
// For now, return a mock result
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
|
||||
// Simulate varying confidence
|
||||
let confidence = 0.7 + (rand::random::<f64>() * 0.3);
|
||||
|
||||
Ok(OcrResult {
|
||||
file: file.clone(),
|
||||
text: format!("OCR text from {}", file.display()),
|
||||
latex: Some(format!(r"\text{{Content from {}}}", file.display())),
|
||||
confidence,
|
||||
processing_time_ms: 50,
|
||||
errors: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn save_results(
|
||||
results: &[OcrResult],
|
||||
output_dir: &PathBuf,
|
||||
format: &OutputFormat,
|
||||
separate_files: bool,
|
||||
) -> Result<()> {
|
||||
if separate_files {
|
||||
// Save each result as a separate file
|
||||
for (idx, result) in results.iter().enumerate() {
|
||||
let filename = format!(
|
||||
"result_{:04}.{}",
|
||||
idx,
|
||||
match format {
|
||||
OutputFormat::Json => "json",
|
||||
OutputFormat::Latex => "tex",
|
||||
OutputFormat::Markdown => "md",
|
||||
OutputFormat::MathMl => "xml",
|
||||
OutputFormat::Text => "txt",
|
||||
}
|
||||
);
|
||||
|
||||
let output_path = output_dir.join(filename);
|
||||
let content = format_single_result(result, format)?;
|
||||
|
||||
std::fs::write(&output_path, content)
|
||||
.context(format!("Failed to write {}", output_path.display()))?;
|
||||
}
|
||||
} else {
|
||||
// Save all results as a single file
|
||||
let filename = format!(
|
||||
"results.{}",
|
||||
match format {
|
||||
OutputFormat::Json => "json",
|
||||
OutputFormat::Latex => "tex",
|
||||
OutputFormat::Markdown => "md",
|
||||
OutputFormat::MathMl => "xml",
|
||||
OutputFormat::Text => "txt",
|
||||
}
|
||||
);
|
||||
|
||||
let output_path = output_dir.join(filename);
|
||||
let content = format_batch_results(results, format)?;
|
||||
|
||||
std::fs::write(&output_path, content).context("Failed to write results file")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn format_single_result(result: &OcrResult, format: &OutputFormat) -> Result<String> {
|
||||
match format {
|
||||
OutputFormat::Json => {
|
||||
serde_json::to_string_pretty(result).context("Failed to serialize result")
|
||||
}
|
||||
OutputFormat::Text => Ok(result.text.clone()),
|
||||
OutputFormat::Latex => Ok(result.latex.clone().unwrap_or_else(|| result.text.clone())),
|
||||
OutputFormat::Markdown => Ok(format!("# {}\n\n{}\n", result.file.display(), result.text)),
|
||||
OutputFormat::MathMl => Ok(format!(
|
||||
"<math xmlns=\"http://www.w3.org/1998/Math/MathML\">\n {}\n</math>",
|
||||
result.text
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_batch_results(results: &[OcrResult], format: &OutputFormat) -> Result<String> {
|
||||
match format {
|
||||
OutputFormat::Json => {
|
||||
serde_json::to_string_pretty(results).context("Failed to serialize results")
|
||||
}
|
||||
_ => {
|
||||
let mut output = String::new();
|
||||
for result in results {
|
||||
output.push_str(&format_single_result(result, format)?);
|
||||
output.push_str("\n\n---\n\n");
|
||||
}
|
||||
Ok(output)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
|
||||
if let Some(path) = config_path {
|
||||
let content = std::fs::read_to_string(path).context("Failed to read config file")?;
|
||||
toml::from_str(&content).context("Failed to parse config file")
|
||||
} else {
|
||||
Ok(OcrConfig::default())
|
||||
}
|
||||
}
|
||||
272
examples/scipix/src/cli/commands/config.rs
Normal file
272
examples/scipix/src/cli/commands/config.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::{Args, Subcommand};
|
||||
use dialoguer::{theme::ColorfulTheme, Confirm, Input};
|
||||
use std::path::PathBuf;
|
||||
use tracing::info;
|
||||
|
||||
use super::OcrConfig;
|
||||
use crate::cli::Cli;
|
||||
|
||||
/// Manage configuration
// Entry arguments for the `config` subcommand; the concrete operation
// (init/validate/show/edit/path) is selected by the nested subcommand.
// NOTE: plain `//` comments are used here so clap's derive does not pick
// them up as additional help text.
#[derive(Args, Debug, Clone)]
pub struct ConfigArgs {
    #[command(subcommand)]
    pub command: ConfigCommand,
}
|
||||
|
||||
// Operations available under the `config` subcommand. The `///` variant doc
// comments double as clap help text and are therefore left exactly as
// authored.
#[derive(Subcommand, Debug, Clone)]
pub enum ConfigCommand {
    /// Generate default configuration file
    Init {
        /// Output path for config file
        #[arg(short, long, default_value = "scipix.toml")]
        output: PathBuf,

        /// Overwrite existing file
        #[arg(short, long)]
        force: bool,
    },

    /// Validate configuration file
    Validate {
        /// Path to config file to validate
        #[arg(value_name = "FILE")]
        file: PathBuf,
    },

    /// Show current configuration
    Show {
        /// Path to config file (default: from --config or scipix.toml)
        #[arg(value_name = "FILE")]
        file: Option<PathBuf>,
    },

    /// Edit configuration interactively
    Edit {
        /// Path to config file to edit
        #[arg(value_name = "FILE")]
        file: PathBuf,
    },

    /// Get configuration directory path
    Path,
}
|
||||
|
||||
pub async fn execute(args: ConfigArgs, cli: &Cli) -> Result<()> {
|
||||
match args.command {
|
||||
ConfigCommand::Init { output, force } => {
|
||||
init_config(&output, force)?;
|
||||
}
|
||||
ConfigCommand::Validate { file } => {
|
||||
validate_config(&file)?;
|
||||
}
|
||||
ConfigCommand::Show { file } => {
|
||||
show_config(file.or(cli.config.clone()))?;
|
||||
}
|
||||
ConfigCommand::Edit { file } => {
|
||||
edit_config(&file)?;
|
||||
}
|
||||
ConfigCommand::Path => {
|
||||
show_config_path()?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn init_config(output: &PathBuf, force: bool) -> Result<()> {
|
||||
if output.exists() && !force {
|
||||
anyhow::bail!(
|
||||
"Config file already exists: {} (use --force to overwrite)",
|
||||
output.display()
|
||||
);
|
||||
}
|
||||
|
||||
let config = OcrConfig::default();
|
||||
let toml = toml::to_string_pretty(&config).context("Failed to serialize config")?;
|
||||
|
||||
std::fs::write(output, toml).context("Failed to write config file")?;
|
||||
|
||||
info!("Configuration file created: {}", output.display());
|
||||
println!("✓ Created configuration file: {}", output.display());
|
||||
println!("\nTo use this config, run:");
|
||||
println!(" scipix-cli --config {} <command>", output.display());
|
||||
println!("\nOr set environment variable:");
|
||||
println!(" export MATHPIX_CONFIG={}", output.display());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_config(file: &PathBuf) -> Result<()> {
|
||||
if !file.exists() {
|
||||
anyhow::bail!("Config file not found: {}", file.display());
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(file).context("Failed to read config file")?;
|
||||
|
||||
let config: OcrConfig = toml::from_str(&content).context("Failed to parse config file")?;
|
||||
|
||||
// Validate configuration values
|
||||
if config.min_confidence < 0.0 || config.min_confidence > 1.0 {
|
||||
anyhow::bail!("min_confidence must be between 0.0 and 1.0");
|
||||
}
|
||||
|
||||
if config.max_image_size == 0 {
|
||||
anyhow::bail!("max_image_size must be greater than 0");
|
||||
}
|
||||
|
||||
if config.supported_extensions.is_empty() {
|
||||
anyhow::bail!("supported_extensions cannot be empty");
|
||||
}
|
||||
|
||||
println!("✓ Configuration is valid");
|
||||
println!("\nSettings:");
|
||||
println!(" Min confidence: {}", config.min_confidence);
|
||||
println!(" Max image size: {} bytes", config.max_image_size);
|
||||
println!(
|
||||
" Supported extensions: {}",
|
||||
config.supported_extensions.join(", ")
|
||||
);
|
||||
|
||||
if let Some(endpoint) = &config.api_endpoint {
|
||||
println!(" API endpoint: {}", endpoint);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn show_config(file: Option<PathBuf>) -> Result<()> {
|
||||
let config_path = file.unwrap_or_else(|| PathBuf::from("scipix.toml"));
|
||||
|
||||
if !config_path.exists() {
|
||||
println!("No configuration file found.");
|
||||
println!("\nCreate one with:");
|
||||
println!(" scipix-cli config init");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&config_path).context("Failed to read config file")?;
|
||||
|
||||
println!("Configuration from: {}\n", config_path.display());
|
||||
println!("{}", content);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Interactively edit an existing configuration file.
///
/// Loads `file`, prompts for the confidence threshold, the maximum image
/// size, and (optionally) an API endpoint, then asks for confirmation before
/// writing the updated TOML back. Fails if the file does not exist or cannot
/// be parsed.
fn edit_config(file: &PathBuf) -> Result<()> {
    if !file.exists() {
        anyhow::bail!(
            "Config file not found: {} (use 'config init' to create)",
            file.display()
        );
    }

    let content = std::fs::read_to_string(file).context("Failed to read config file")?;

    let mut config: OcrConfig = toml::from_str(&content).context("Failed to parse config file")?;

    let theme = ColorfulTheme::default();

    println!("Interactive Configuration Editor\n");

    // Edit min_confidence — the validator rejects anything outside [0.0, 1.0]
    // so the saved config can never hold an out-of-range threshold.
    config.min_confidence = Input::with_theme(&theme)
        .with_prompt("Minimum confidence threshold (0.0-1.0)")
        .default(config.min_confidence)
        .validate_with(|v: &f64| {
            if *v >= 0.0 && *v <= 1.0 {
                Ok(())
            } else {
                Err("Value must be between 0.0 and 1.0")
            }
        })
        .interact_text()
        .context("Failed to read input")?;

    // Edit max_image_size — prompted in MB for usability, stored in bytes.
    let max_size_mb = config.max_image_size / (1024 * 1024);
    let new_size_mb: usize = Input::with_theme(&theme)
        .with_prompt("Maximum image size (MB)")
        .default(max_size_mb)
        .interact_text()
        .context("Failed to read input")?;
    config.max_image_size = new_size_mb * 1024 * 1024;

    // Edit API endpoint — the flow differs depending on whether an endpoint
    // is already configured (edit/clear) or not (add).
    if config.api_endpoint.is_some() {
        let edit_endpoint = Confirm::with_theme(&theme)
            .with_prompt("Edit API endpoint?")
            .default(false)
            .interact()
            .context("Failed to read input")?;

        if edit_endpoint {
            let endpoint: String = Input::with_theme(&theme)
                .with_prompt("API endpoint URL")
                .allow_empty(true)
                .interact_text()
                .context("Failed to read input")?;

            // An empty answer clears the endpoint.
            config.api_endpoint = if endpoint.is_empty() {
                None
            } else {
                Some(endpoint)
            };
        }
    } else {
        let add_endpoint = Confirm::with_theme(&theme)
            .with_prompt("Add API endpoint?")
            .default(false)
            .interact()
            .context("Failed to read input")?;

        if add_endpoint {
            let endpoint: String = Input::with_theme(&theme)
                .with_prompt("API endpoint URL")
                .interact_text()
                .context("Failed to read input")?;

            config.api_endpoint = Some(endpoint);
        }
    }

    // Save configuration — changes are only persisted after confirmation.
    let save = Confirm::with_theme(&theme)
        .with_prompt("Save changes?")
        .default(true)
        .interact()
        .context("Failed to read input")?;

    if save {
        let toml = toml::to_string_pretty(&config).context("Failed to serialize config")?;

        std::fs::write(file, toml).context("Failed to write config file")?;

        println!("\n✓ Configuration saved to: {}", file.display());
    } else {
        println!("\nChanges discarded.");
    }

    Ok(())
}
|
||||
|
||||
fn show_config_path() -> Result<()> {
|
||||
if let Some(config_dir) = dirs::config_dir() {
|
||||
let app_config = config_dir.join("scipix");
|
||||
println!("Default config directory: {}", app_config.display());
|
||||
|
||||
if !app_config.exists() {
|
||||
println!("\nDirectory does not exist. Create it with:");
|
||||
println!(" mkdir -p {}", app_config.display());
|
||||
}
|
||||
} else {
|
||||
println!("Could not determine config directory");
|
||||
}
|
||||
|
||||
println!("\nYou can also use a custom config file:");
|
||||
println!(" scipix-cli --config /path/to/config.toml <command>");
|
||||
println!("\nOr set environment variable:");
|
||||
println!(" export MATHPIX_CONFIG=/path/to/config.toml");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
955
examples/scipix/src/cli/commands/doctor.rs
Normal file
955
examples/scipix/src/cli/commands/doctor.rs
Normal file
@@ -0,0 +1,955 @@
|
||||
//! Doctor command for environment analysis and configuration optimization
|
||||
//!
|
||||
//! Analyzes the system environment and provides recommendations for optimal
|
||||
//! SciPix configuration based on available hardware and software capabilities.
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Args;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Arguments for the doctor command
// Controls which diagnostic categories run and how results are reported.
// NOTE: field help text comes from the explicit `help = ...` attributes, so
// the `///` lines below only document the code, and plain `//` comments are
// used for extra notes to keep the CLI help output unchanged.
#[derive(Args, Debug, Clone)]
pub struct DoctorArgs {
    /// Run in fix mode to automatically apply recommendations
    #[arg(long, help = "Automatically apply safe fixes")]
    pub fix: bool,

    /// Output detailed diagnostic information
    #[arg(long, short, help = "Show detailed diagnostic information")]
    pub verbose: bool,

    /// Output results as JSON
    // When set, the human-readable report is suppressed entirely.
    #[arg(long, help = "Output results as JSON")]
    pub json: bool,

    /// Check only specific category (cpu, memory, config, deps, all)
    #[arg(long, default_value = "all", help = "Category to check")]
    pub check: CheckCategory,

    /// Path to configuration file to validate
    // Optional; the config check also probes conventional locations.
    #[arg(long, help = "Path to configuration file to validate")]
    pub config_path: Option<PathBuf>,
}
|
||||
|
||||
// Diagnostic category selector for `--check`. Plain `//` comments are used
// so the ValueEnum derive does not turn them into CLI help text.
#[derive(Debug, Clone, Copy, clap::ValueEnum, Default)]
pub enum CheckCategory {
    // Run every category below (the default).
    #[default]
    All,
    // CPU core count and SIMD capability checks.
    Cpu,
    // Total/available RAM checks.
    Memory,
    // Configuration file presence and validity checks.
    Config,
    // External dependency checks (ONNX Runtime, OpenSSL, rustc).
    Deps,
    // Network/API reachability checks.
    Network,
}
|
||||
|
||||
/// Status of a diagnostic check
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckStatus {
    /// Check succeeded with no caveats.
    Pass,
    /// Check succeeded, but a better setup is possible.
    Warning,
    /// Check failed; functionality or performance is impaired.
    Fail,
    /// Informational entry only; not a pass/fail result.
    Info,
}
|
||||
|
||||
impl std::fmt::Display for CheckStatus {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
CheckStatus::Pass => write!(f, "✓"),
|
||||
CheckStatus::Warning => write!(f, "⚠"),
|
||||
CheckStatus::Fail => write!(f, "✗"),
|
||||
CheckStatus::Info => write!(f, "ℹ"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single diagnostic check result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticCheck {
    /// Short human-readable name of the check (e.g. "CPU Cores").
    pub name: String,
    /// Grouping label ("CPU", "Memory", "Config", ...).
    pub category: String,
    /// Outcome of the check.
    pub status: CheckStatus,
    /// Human-readable detail about what was found.
    pub message: String,
    /// Suggested remediation when the check did not fully pass.
    pub recommendation: Option<String>,
    /// Whether `--fix` can apply the remediation automatically.
    pub auto_fixable: bool,
}
|
||||
|
||||
/// Complete diagnostic report
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiagnosticReport {
    /// RFC 3339 timestamp of when the report was generated.
    pub timestamp: String,
    /// Snapshot of the host hardware/software environment.
    pub system_info: SystemInfo,
    /// All individual check results.
    pub checks: Vec<DiagnosticCheck>,
    /// Recommendations aggregated from the checks that produced one.
    pub recommendations: Vec<String>,
    /// Suggested configuration derived from the system info.
    pub optimal_config: OptimalConfig,
}
|
||||
|
||||
/// System information gathered during diagnosis
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SystemInfo {
    /// Operating system name (from `std::env::consts::OS`).
    pub os: String,
    /// CPU architecture (from `std::env::consts::ARCH`).
    pub arch: String,
    /// Number of logical CPUs.
    pub cpu_count: usize,
    /// Human-readable CPU model string (best effort).
    pub cpu_brand: String,
    /// Total physical memory in MB.
    pub total_memory_mb: u64,
    /// Memory currently available in MB.
    pub available_memory_mb: u64,
    /// Detected SIMD instruction-set support.
    pub simd_features: SimdFeatures,
}
|
||||
|
||||
/// SIMD feature detection results
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdFeatures {
    /// x86-64 baseline 128-bit SIMD.
    pub sse2: bool,
    /// SSE 4.1 extensions.
    pub sse4_1: bool,
    /// SSE 4.2 extensions.
    pub sse4_2: bool,
    /// 256-bit floating-point SIMD.
    pub avx: bool,
    /// 256-bit integer SIMD.
    pub avx2: bool,
    /// AVX-512 foundation subset.
    pub avx512f: bool,
    /// ARM 128-bit SIMD (always present on AArch64).
    pub neon: bool,
    /// Name of the widest usable instruction set ("scalar" if none detected).
    pub best_available: String,
}
|
||||
|
||||
/// Optimal configuration recommendations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimalConfig {
    /// Suggested number of items per processing batch.
    pub batch_size: usize,
    /// Suggested worker thread count.
    pub worker_threads: usize,
    /// SIMD backend to select (e.g. "AVX2").
    pub simd_backend: String,
    /// Soft memory ceiling for processing, in MB.
    pub memory_limit_mb: u64,
    /// Preprocessing pipeline mode to use.
    pub preprocessing_mode: String,
    /// Whether the result cache should be enabled.
    pub cache_enabled: bool,
    /// Result cache size budget, in MB.
    pub cache_size_mb: u64,
}
|
||||
|
||||
/// Execute the doctor command
///
/// Gathers system information, runs the selected diagnostic categories,
/// derives an optimal configuration, and reports everything either as
/// human-readable text or (with `--json`) as a serialized report. With
/// `--fix`, auto-fixable findings are applied after reporting.
pub async fn execute(args: DoctorArgs) -> Result<()> {
    if !args.json {
        println!("🩺 SciPix Doctor - Environment Analysis\n");
        println!("═══════════════════════════════════════════════════════════\n");
    }

    let mut checks = Vec::new();

    // Gather system information
    let system_info = gather_system_info();

    // Run checks based on category
    match args.check {
        CheckCategory::All => {
            checks.extend(check_cpu(&system_info, args.verbose));
            checks.extend(check_memory(&system_info, args.verbose));
            checks.extend(check_dependencies(args.verbose));
            checks.extend(check_config(&args.config_path, args.verbose));
            checks.extend(check_network(args.verbose).await);
        }
        CheckCategory::Cpu => {
            checks.extend(check_cpu(&system_info, args.verbose));
        }
        CheckCategory::Memory => {
            checks.extend(check_memory(&system_info, args.verbose));
        }
        CheckCategory::Config => {
            checks.extend(check_config(&args.config_path, args.verbose));
        }
        CheckCategory::Deps => {
            checks.extend(check_dependencies(args.verbose));
        }
        CheckCategory::Network => {
            checks.extend(check_network(args.verbose).await);
        }
    }

    // Generate optimal configuration
    let optimal_config = generate_optimal_config(&system_info);

    // Collect recommendations from every check that produced one.
    let recommendations: Vec<String> = checks
        .iter()
        .filter_map(|c| c.recommendation.clone())
        .collect();

    // Create report
    let report = DiagnosticReport {
        timestamp: chrono::Utc::now().to_rfc3339(),
        system_info: system_info.clone(),
        checks: checks.clone(),
        recommendations: recommendations.clone(),
        optimal_config: optimal_config.clone(),
    };

    // JSON mode: emit only the machine-readable report and return early.
    if args.json {
        println!("{}", serde_json::to_string_pretty(&report)?);
        return Ok(());
    }

    // Print system info
    print_system_info(&system_info);

    // Print check results
    print_check_results(&checks);

    // Print recommendations
    if !recommendations.is_empty() {
        println!("\n📋 Recommendations:");
        println!("───────────────────────────────────────────────────────────");
        for (i, rec) in recommendations.iter().enumerate() {
            println!(" {}. {}", i + 1, rec);
        }
    }

    // Print optimal configuration
    print_optimal_config(&optimal_config);

    // Apply fixes if requested
    if args.fix {
        apply_fixes(&checks).await?;
    }

    // Print summary
    print_summary(&checks);

    Ok(())
}
|
||||
|
||||
fn gather_system_info() -> SystemInfo {
|
||||
let cpu_count = num_cpus::get();
|
||||
|
||||
// Get CPU brand string
|
||||
let cpu_brand = get_cpu_brand();
|
||||
|
||||
// Get memory info
|
||||
let (total_memory_mb, available_memory_mb) = get_memory_info();
|
||||
|
||||
// Detect SIMD features
|
||||
let simd_features = detect_simd_features();
|
||||
|
||||
SystemInfo {
|
||||
os: std::env::consts::OS.to_string(),
|
||||
arch: std::env::consts::ARCH.to_string(),
|
||||
cpu_count,
|
||||
cpu_brand,
|
||||
total_memory_mb,
|
||||
available_memory_mb,
|
||||
simd_features,
|
||||
}
|
||||
}
|
||||
|
||||
/// Best-effort human-readable CPU model name.
///
/// On x86_64 this tries the platform-specific lookup first; otherwise (or on
/// lookup failure) it returns a generic "<arch> processor" label.
fn get_cpu_brand() -> String {
    #[cfg(target_arch = "x86_64")]
    {
        if let Some(brand) = get_x86_cpu_brand() {
            return brand;
        }
    }

    // Fallback
    format!("{} processor", std::env::consts::ARCH)
}
|
||||
|
||||
/// Read the CPU model name from `/proc/cpuinfo`.
///
/// Returns `None` when the file is unavailable (e.g. non-Linux x86_64
/// hosts) or when no "model name" line with a value is present.
#[cfg(target_arch = "x86_64")]
fn get_x86_cpu_brand() -> Option<String> {
    let cpuinfo = std::fs::read_to_string("/proc/cpuinfo").ok()?;
    cpuinfo
        .lines()
        .filter(|line| line.starts_with("model name"))
        .filter_map(|line| line.split(':').nth(1))
        .map(|brand| brand.trim().to_string())
        .next()
}
|
||||
|
||||
#[cfg(not(target_arch = "x86_64"))]
fn get_x86_cpu_brand() -> Option<String> {
    // No cpuinfo-based model-name lookup on non-x86_64 targets; the caller
    // falls back to a generic "<arch> processor" label.
    None
}
|
||||
|
||||
/// Total and available system memory in MB, read from `/proc/meminfo`.
///
/// Falls back to (8192, 4096) when the file is missing (non-Linux) or
/// reports no total.
fn get_memory_info() -> (u64, u64) {
    if let Ok(meminfo) = std::fs::read_to_string("/proc/meminfo") {
        // The second whitespace-separated field of a meminfo line is the
        // value in kB.
        let field_kb = |line: &str| -> Option<u64> {
            line.split_whitespace().nth(1).and_then(|s| s.parse().ok())
        };

        let mut total = 0u64;
        let mut available = 0u64;

        for line in meminfo.lines() {
            if line.starts_with("MemTotal:") {
                if let Some(kb) = field_kb(line) {
                    total = kb / 1024; // Convert to MB
                }
            } else if line.starts_with("MemAvailable:") {
                if let Some(kb) = field_kb(line) {
                    available = kb / 1024; // Convert to MB
                }
            }
        }

        if total > 0 {
            return (total, available);
        }
    }

    // Conservative defaults when /proc/meminfo is unavailable.
    (8192, 4096)
}
|
||||
|
||||
/// Parse the numeric kB value from a `/proc/meminfo` line such as
/// `"MemTotal:       16384 kB"`.
///
/// Returns `None` when the second whitespace-separated field is missing or
/// not an unsigned integer.
fn parse_meminfo_value(line: &str) -> Option<u64> {
    let mut fields = line.split_whitespace();
    fields.nth(1)?.parse().ok()
}
|
||||
|
||||
fn detect_simd_features() -> SimdFeatures {
|
||||
let mut features = SimdFeatures {
|
||||
sse2: false,
|
||||
sse4_1: false,
|
||||
sse4_2: false,
|
||||
avx: false,
|
||||
avx2: false,
|
||||
avx512f: false,
|
||||
neon: false,
|
||||
best_available: "scalar".to_string(),
|
||||
};
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
{
|
||||
features.sse2 = is_x86_feature_detected!("sse2");
|
||||
features.sse4_1 = is_x86_feature_detected!("sse4.1");
|
||||
features.sse4_2 = is_x86_feature_detected!("sse4.2");
|
||||
features.avx = is_x86_feature_detected!("avx");
|
||||
features.avx2 = is_x86_feature_detected!("avx2");
|
||||
features.avx512f = is_x86_feature_detected!("avx512f");
|
||||
|
||||
if features.avx512f {
|
||||
features.best_available = "AVX-512".to_string();
|
||||
} else if features.avx2 {
|
||||
features.best_available = "AVX2".to_string();
|
||||
} else if features.avx {
|
||||
features.best_available = "AVX".to_string();
|
||||
} else if features.sse4_2 {
|
||||
features.best_available = "SSE4.2".to_string();
|
||||
} else if features.sse2 {
|
||||
features.best_available = "SSE2".to_string();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(target_arch = "aarch64")]
|
||||
{
|
||||
features.neon = true; // NEON is always available on AArch64
|
||||
features.best_available = "NEON".to_string();
|
||||
}
|
||||
|
||||
features
|
||||
}
|
||||
|
||||
fn check_cpu(system_info: &SystemInfo, verbose: bool) -> Vec<DiagnosticCheck> {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// CPU count check
|
||||
let cpu_status = if system_info.cpu_count >= 8 {
|
||||
CheckStatus::Pass
|
||||
} else if system_info.cpu_count >= 4 {
|
||||
CheckStatus::Warning
|
||||
} else {
|
||||
CheckStatus::Fail
|
||||
};
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "CPU Cores".to_string(),
|
||||
category: "CPU".to_string(),
|
||||
status: cpu_status,
|
||||
message: format!("{} cores detected", system_info.cpu_count),
|
||||
recommendation: if system_info.cpu_count < 4 {
|
||||
Some(
|
||||
"Consider running on a machine with more CPU cores for better batch processing"
|
||||
.to_string(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
// SIMD check
|
||||
let simd_status = match system_info.simd_features.best_available.as_str() {
|
||||
"AVX-512" | "AVX2" => CheckStatus::Pass,
|
||||
"AVX" | "SSE4.2" | "NEON" => CheckStatus::Warning,
|
||||
_ => CheckStatus::Fail,
|
||||
};
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "SIMD Support".to_string(),
|
||||
category: "CPU".to_string(),
|
||||
status: simd_status,
|
||||
message: format!(
|
||||
"Best SIMD: {} (SSE2: {}, AVX: {}, AVX2: {}, AVX-512: {})",
|
||||
system_info.simd_features.best_available,
|
||||
if system_info.simd_features.sse2 {
|
||||
"✓"
|
||||
} else {
|
||||
"✗"
|
||||
},
|
||||
if system_info.simd_features.avx {
|
||||
"✓"
|
||||
} else {
|
||||
"✗"
|
||||
},
|
||||
if system_info.simd_features.avx2 {
|
||||
"✓"
|
||||
} else {
|
||||
"✗"
|
||||
},
|
||||
if system_info.simd_features.avx512f {
|
||||
"✓"
|
||||
} else {
|
||||
"✗"
|
||||
},
|
||||
),
|
||||
recommendation: if simd_status == CheckStatus::Fail {
|
||||
Some("Upgrade to a CPU with AVX2 support for 4x faster preprocessing".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
if verbose {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "CPU Brand".to_string(),
|
||||
category: "CPU".to_string(),
|
||||
status: CheckStatus::Info,
|
||||
message: system_info.cpu_brand.clone(),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
|
||||
checks
|
||||
}
|
||||
|
||||
fn check_memory(system_info: &SystemInfo, verbose: bool) -> Vec<DiagnosticCheck> {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// Total memory check
|
||||
let mem_status = if system_info.total_memory_mb >= 16384 {
|
||||
CheckStatus::Pass
|
||||
} else if system_info.total_memory_mb >= 8192 {
|
||||
CheckStatus::Warning
|
||||
} else {
|
||||
CheckStatus::Fail
|
||||
};
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Total Memory".to_string(),
|
||||
category: "Memory".to_string(),
|
||||
status: mem_status,
|
||||
message: format!("{} MB total", system_info.total_memory_mb),
|
||||
recommendation: if system_info.total_memory_mb < 8192 {
|
||||
Some("Consider upgrading to at least 8GB RAM for optimal batch processing".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
// Available memory check
|
||||
let avail_ratio = system_info.available_memory_mb as f64 / system_info.total_memory_mb as f64;
|
||||
let avail_status = if avail_ratio >= 0.5 {
|
||||
CheckStatus::Pass
|
||||
} else if avail_ratio >= 0.25 {
|
||||
CheckStatus::Warning
|
||||
} else {
|
||||
CheckStatus::Fail
|
||||
};
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Available Memory".to_string(),
|
||||
category: "Memory".to_string(),
|
||||
status: avail_status,
|
||||
message: format!(
|
||||
"{} MB available ({:.1}%)",
|
||||
system_info.available_memory_mb,
|
||||
avail_ratio * 100.0
|
||||
),
|
||||
recommendation: if avail_status == CheckStatus::Fail {
|
||||
Some("Close some applications to free up memory before batch processing".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
if verbose {
|
||||
// Memory per core
|
||||
let mem_per_core = system_info.total_memory_mb / system_info.cpu_count as u64;
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Memory per Core".to_string(),
|
||||
category: "Memory".to_string(),
|
||||
status: CheckStatus::Info,
|
||||
message: format!("{} MB/core", mem_per_core),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
|
||||
checks
|
||||
}
|
||||
|
||||
fn check_dependencies(verbose: bool) -> Vec<DiagnosticCheck> {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// Check for ONNX Runtime
|
||||
let onnx_status = check_onnx_runtime();
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "ONNX Runtime".to_string(),
|
||||
category: "Dependencies".to_string(),
|
||||
status: if onnx_status.0 {
|
||||
CheckStatus::Pass
|
||||
} else {
|
||||
CheckStatus::Warning
|
||||
},
|
||||
message: onnx_status.1.clone(),
|
||||
recommendation: if !onnx_status.0 {
|
||||
Some(
|
||||
"Install ONNX Runtime for neural network acceleration: https://onnxruntime.ai/"
|
||||
.to_string(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
// Check for image processing libraries
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Image Processing".to_string(),
|
||||
category: "Dependencies".to_string(),
|
||||
status: CheckStatus::Pass,
|
||||
message: "image crate available (built-in)".to_string(),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
// Check for OpenSSL (for HTTPS)
|
||||
let openssl_available = std::process::Command::new("openssl")
|
||||
.arg("version")
|
||||
.output()
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "OpenSSL".to_string(),
|
||||
category: "Dependencies".to_string(),
|
||||
status: if openssl_available {
|
||||
CheckStatus::Pass
|
||||
} else {
|
||||
CheckStatus::Warning
|
||||
},
|
||||
message: if openssl_available {
|
||||
"OpenSSL available for HTTPS".to_string()
|
||||
} else {
|
||||
"OpenSSL not found".to_string()
|
||||
},
|
||||
recommendation: if !openssl_available {
|
||||
Some("Install OpenSSL for secure API communication".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
if verbose {
|
||||
// Check Rust version
|
||||
if let Ok(output) = std::process::Command::new("rustc")
|
||||
.arg("--version")
|
||||
.output()
|
||||
{
|
||||
let version = String::from_utf8_lossy(&output.stdout);
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Rust Compiler".to_string(),
|
||||
category: "Dependencies".to_string(),
|
||||
status: CheckStatus::Info,
|
||||
message: version.trim().to_string(),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
checks
|
||||
}
|
||||
|
||||
/// Detect an ONNX Runtime shared library installation.
///
/// Returns `(found, human-readable message)`. The filesystem scan covers
/// common Linux and macOS install locations; the `ORT_DYLIB_PATH`
/// environment variable is honored as a fallback (matching the `ort`
/// crate's load-dynamic convention).
fn check_onnx_runtime() -> (bool, String) {
    // Common install locations. The original list only covered three Linux
    // paths; lib64 layouts and macOS (.dylib) installs were missed.
    let lib_paths = [
        "/usr/lib/libonnxruntime.so",
        "/usr/lib64/libonnxruntime.so",
        "/usr/local/lib/libonnxruntime.so",
        "/opt/onnxruntime/lib/libonnxruntime.so",
        "/usr/local/lib/libonnxruntime.dylib",
        "/opt/homebrew/lib/libonnxruntime.dylib",
    ];

    for path in &lib_paths {
        if std::path::Path::new(path).exists() {
            return (true, format!("Found at {}", path));
        }
    }

    // Check via environment variable (explicit override wins only when no
    // library was found on disk, preserving the original precedence).
    if std::env::var("ORT_DYLIB_PATH").is_ok() {
        return (true, "Configured via ORT_DYLIB_PATH".to_string());
    }

    (
        false,
        "Not found (optional for ONNX acceleration)".to_string(),
    )
}
|
||||
|
||||
fn check_config(config_path: &Option<PathBuf>, verbose: bool) -> Vec<DiagnosticCheck> {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// Check for config file
|
||||
let config_locations = [
|
||||
config_path.clone(),
|
||||
Some(PathBuf::from("scipix.toml")),
|
||||
Some(PathBuf::from("config/scipix.toml")),
|
||||
dirs::config_dir().map(|p| p.join("scipix/config.toml")),
|
||||
];
|
||||
|
||||
let mut found_config = false;
|
||||
for loc in config_locations.iter().flatten() {
|
||||
if loc.exists() {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Configuration File".to_string(),
|
||||
category: "Config".to_string(),
|
||||
status: CheckStatus::Pass,
|
||||
message: format!("Found at {}", loc.display()),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
found_config = true;
|
||||
|
||||
// Validate config content
|
||||
if let Ok(content) = std::fs::read_to_string(loc) {
|
||||
if content.contains("[api]") || content.contains("[processing]") {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Config Validity".to_string(),
|
||||
category: "Config".to_string(),
|
||||
status: CheckStatus::Pass,
|
||||
message: "Configuration file is valid".to_string(),
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !found_config {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Configuration File".to_string(),
|
||||
category: "Config".to_string(),
|
||||
status: CheckStatus::Info,
|
||||
message: "No configuration file found (using defaults)".to_string(),
|
||||
recommendation: Some("Create a scipix.toml for custom settings".to_string()),
|
||||
auto_fixable: true,
|
||||
});
|
||||
}
|
||||
|
||||
// Check environment variables
|
||||
let env_vars = [
|
||||
("SCIPIX_API_KEY", "API authentication"),
|
||||
("SCIPIX_MODEL_PATH", "Custom model path"),
|
||||
("SCIPIX_CACHE_DIR", "Cache directory"),
|
||||
];
|
||||
|
||||
for (var, desc) in &env_vars {
|
||||
let status = if std::env::var(var).is_ok() {
|
||||
CheckStatus::Pass
|
||||
} else {
|
||||
CheckStatus::Info
|
||||
};
|
||||
|
||||
if verbose || status == CheckStatus::Pass {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: format!("Env: {}", var),
|
||||
category: "Config".to_string(),
|
||||
status,
|
||||
message: if status == CheckStatus::Pass {
|
||||
format!("{} configured", desc)
|
||||
} else {
|
||||
format!("{} not set (optional)", desc)
|
||||
},
|
||||
recommendation: None,
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
checks
|
||||
}
|
||||
|
||||
async fn check_network(verbose: bool) -> Vec<DiagnosticCheck> {
|
||||
let mut checks = Vec::new();
|
||||
|
||||
// Check localhost binding
|
||||
let localhost_available = tokio::net::TcpListener::bind("127.0.0.1:0").await.is_ok();
|
||||
|
||||
checks.push(DiagnosticCheck {
|
||||
name: "Localhost Binding".to_string(),
|
||||
category: "Network".to_string(),
|
||||
status: if localhost_available {
|
||||
CheckStatus::Pass
|
||||
} else {
|
||||
CheckStatus::Fail
|
||||
},
|
||||
message: if localhost_available {
|
||||
"Can bind to localhost".to_string()
|
||||
} else {
|
||||
"Cannot bind to localhost".to_string()
|
||||
},
|
||||
recommendation: if !localhost_available {
|
||||
Some("Check firewall settings and port availability".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
|
||||
// Check common ports
|
||||
let ports_to_check = [(8080, "API server"), (3000, "Alternative API")];
|
||||
|
||||
for (port, desc) in &ports_to_check {
|
||||
let available = tokio::net::TcpListener::bind(format!("127.0.0.1:{}", port))
|
||||
.await
|
||||
.is_ok();
|
||||
|
||||
if verbose || !available {
|
||||
checks.push(DiagnosticCheck {
|
||||
name: format!("Port {}", port),
|
||||
category: "Network".to_string(),
|
||||
status: if available {
|
||||
CheckStatus::Pass
|
||||
} else {
|
||||
CheckStatus::Warning
|
||||
},
|
||||
message: if available {
|
||||
format!("Port {} ({}) available", port, desc)
|
||||
} else {
|
||||
format!("Port {} ({}) in use", port, desc)
|
||||
},
|
||||
recommendation: if !available {
|
||||
Some(format!(
|
||||
"Free port {} or use --port to specify alternative",
|
||||
port
|
||||
))
|
||||
} else {
|
||||
None
|
||||
},
|
||||
auto_fixable: false,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
checks
|
||||
}
|
||||
|
||||
fn generate_optimal_config(system_info: &SystemInfo) -> OptimalConfig {
|
||||
// Calculate optimal batch size based on memory
|
||||
let batch_size = if system_info.available_memory_mb >= 8192 {
|
||||
32
|
||||
} else if system_info.available_memory_mb >= 4096 {
|
||||
16
|
||||
} else if system_info.available_memory_mb >= 2048 {
|
||||
8
|
||||
} else {
|
||||
4
|
||||
};
|
||||
|
||||
// Calculate worker threads (leave some headroom)
|
||||
let worker_threads = (system_info.cpu_count as f64 * 0.75).ceil() as usize;
|
||||
let worker_threads = worker_threads.max(2);
|
||||
|
||||
// Determine SIMD backend
|
||||
let simd_backend = system_info.simd_features.best_available.clone();
|
||||
|
||||
// Memory limit (use 60% of available)
|
||||
let memory_limit_mb = (system_info.available_memory_mb as f64 * 0.6) as u64;
|
||||
|
||||
// Preprocessing mode based on SIMD
|
||||
let preprocessing_mode = if system_info.simd_features.avx2 || system_info.simd_features.neon {
|
||||
"simd_optimized".to_string()
|
||||
} else if system_info.simd_features.sse4_2 {
|
||||
"simd_basic".to_string()
|
||||
} else {
|
||||
"scalar".to_string()
|
||||
};
|
||||
|
||||
// Cache settings
|
||||
let cache_enabled = system_info.available_memory_mb >= 2048;
|
||||
let cache_size_mb = if cache_enabled {
|
||||
(system_info.available_memory_mb as f64 * 0.1) as u64
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
OptimalConfig {
|
||||
batch_size,
|
||||
worker_threads,
|
||||
simd_backend,
|
||||
memory_limit_mb,
|
||||
preprocessing_mode,
|
||||
cache_enabled,
|
||||
cache_size_mb,
|
||||
}
|
||||
}
|
||||
|
||||
/// Pretty-print the detected host characteristics (OS, CPU, memory, SIMD).
///
/// Output layout is consumed by humans only; nothing parses it.
fn print_system_info(info: &SystemInfo) {
    println!("📊 System Information:");
    println!("───────────────────────────────────────────────────────────");
    println!("  OS: {} ({})", info.os, info.arch);
    println!("  CPU: {}", info.cpu_brand);
    println!("  Cores: {}", info.cpu_count);
    println!(
        "  Memory: {} MB total, {} MB available",
        info.total_memory_mb, info.available_memory_mb
    );
    // best_available is the strongest SIMD tier detected at startup.
    println!("  Best SIMD: {}", info.simd_features.best_available);
    println!();
}
|
||||
|
||||
fn print_check_results(checks: &[DiagnosticCheck]) {
|
||||
println!("🔍 Diagnostic Checks:");
|
||||
println!("───────────────────────────────────────────────────────────");
|
||||
|
||||
let mut current_category = String::new();
|
||||
for check in checks {
|
||||
if check.category != current_category {
|
||||
if !current_category.is_empty() {
|
||||
println!();
|
||||
}
|
||||
println!(" [{}]", check.category);
|
||||
current_category = check.category.clone();
|
||||
}
|
||||
|
||||
let status_color = match check.status {
|
||||
CheckStatus::Pass => "\x1b[32m", // Green
|
||||
CheckStatus::Warning => "\x1b[33m", // Yellow
|
||||
CheckStatus::Fail => "\x1b[31m", // Red
|
||||
CheckStatus::Info => "\x1b[36m", // Cyan
|
||||
};
|
||||
|
||||
println!(
|
||||
" {}{}\x1b[0m {} - {}",
|
||||
status_color, check.status, check.name, check.message
|
||||
);
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Print the derived optimal configuration, followed by a ready-to-paste
/// `scipix.toml` snippet with the same values.
fn print_optimal_config(config: &OptimalConfig) {
    println!("\n⚙️ Optimal Configuration:");
    println!("───────────────────────────────────────────────────────────");
    println!("  batch_size: {}", config.batch_size);
    println!("  worker_threads: {}", config.worker_threads);
    println!("  simd_backend: {}", config.simd_backend);
    println!("  memory_limit: {} MB", config.memory_limit_mb);
    println!("  preprocessing: {}", config.preprocessing_mode);
    println!("  cache_enabled: {}", config.cache_enabled);
    // Cache size is only meaningful when the cache is on.
    if config.cache_enabled {
        println!("  cache_size: {} MB", config.cache_size_mb);
    }

    println!("\n  📝 Example configuration (scipix.toml):");
    println!("  ─────────────────────────────────────────");
    println!("  [processing]");
    println!("  batch_size = {}", config.batch_size);
    println!("  worker_threads = {}", config.worker_threads);
    println!("  simd_backend = \"{}\"", config.simd_backend);
    println!("  memory_limit_mb = {}", config.memory_limit_mb);
    println!();
    println!("  [cache]");
    println!("  enabled = {}", config.cache_enabled);
    println!("  size_mb = {}", config.cache_size_mb);
}
|
||||
|
||||
fn print_summary(checks: &[DiagnosticCheck]) {
|
||||
let pass_count = checks
|
||||
.iter()
|
||||
.filter(|c| c.status == CheckStatus::Pass)
|
||||
.count();
|
||||
let warn_count = checks
|
||||
.iter()
|
||||
.filter(|c| c.status == CheckStatus::Warning)
|
||||
.count();
|
||||
let fail_count = checks
|
||||
.iter()
|
||||
.filter(|c| c.status == CheckStatus::Fail)
|
||||
.count();
|
||||
|
||||
println!("\n═══════════════════════════════════════════════════════════");
|
||||
println!(
|
||||
"📋 Summary: {} passed, {} warnings, {} failed",
|
||||
pass_count, warn_count, fail_count
|
||||
);
|
||||
|
||||
if fail_count > 0 {
|
||||
println!("\n⚠️ Some checks failed. Review recommendations above.");
|
||||
} else if warn_count > 0 {
|
||||
println!("\n✓ System is functional with some areas for improvement.");
|
||||
} else {
|
||||
println!("\n✅ System is optimally configured for SciPix!");
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_fixes(checks: &[DiagnosticCheck]) -> Result<()> {
|
||||
println!("\n🔧 Applying automatic fixes...");
|
||||
println!("───────────────────────────────────────────────────────────");
|
||||
|
||||
let fixable: Vec<_> = checks.iter().filter(|c| c.auto_fixable).collect();
|
||||
|
||||
if fixable.is_empty() {
|
||||
println!(" No automatic fixes available.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
for check in fixable {
|
||||
println!(" Fixing: {}", check.name);
|
||||
|
||||
if check.name == "Configuration File" {
|
||||
// Create default config file
|
||||
let config_content = r#"# SciPix Configuration
|
||||
# Generated by scipix doctor --fix
|
||||
|
||||
[processing]
|
||||
batch_size = 16
|
||||
worker_threads = 4
|
||||
simd_backend = "auto"
|
||||
memory_limit_mb = 4096
|
||||
|
||||
[cache]
|
||||
enabled = true
|
||||
size_mb = 256
|
||||
|
||||
[api]
|
||||
host = "127.0.0.1"
|
||||
port = 8080
|
||||
timeout_seconds = 30
|
||||
|
||||
[logging]
|
||||
level = "info"
|
||||
format = "pretty"
|
||||
"#;
|
||||
|
||||
// Create config directory if needed
|
||||
let config_path = PathBuf::from("config");
|
||||
if !config_path.exists() {
|
||||
std::fs::create_dir_all(&config_path)?;
|
||||
}
|
||||
|
||||
let config_file = config_path.join("scipix.toml");
|
||||
std::fs::write(&config_file, config_content)?;
|
||||
println!(" ✓ Created {}", config_file.display());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
806
examples/scipix/src/cli/commands/mcp.rs
Normal file
806
examples/scipix/src/cli/commands/mcp.rs
Normal file
@@ -0,0 +1,806 @@
|
||||
//! MCP (Model Context Protocol) Server Implementation for SciPix
|
||||
//!
|
||||
//! Implements the MCP 2024-11-05 specification for exposing OCR capabilities
|
||||
//! as tools that can be discovered and invoked by AI hosts.
|
||||
//!
|
||||
//! ## Usage
|
||||
//! ```bash
|
||||
//! scipix-cli mcp
|
||||
//! ```
|
||||
//!
|
||||
//! ## Protocol
|
||||
//! Uses JSON-RPC 2.0 over STDIO for communication.
|
||||
|
||||
use clap::Args;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::io::{self, BufRead, Write};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// MCP Server Arguments
///
/// CLI flags for the `mcp` subcommand, parsed by clap.
#[derive(Args, Debug, Clone)]
pub struct McpArgs {
    /// Enable debug logging for MCP messages
    /// (protocol traffic is echoed to stderr so STDOUT stays clean for JSON-RPC).
    #[arg(long, help = "Enable debug logging")]
    pub debug: bool,

    /// Custom model path for OCR
    #[arg(long, help = "Path to ONNX models directory")]
    pub models_dir: Option<PathBuf>,
}
|
||||
|
||||
/// JSON-RPC 2.0 Request
///
/// Deserialized from a single line of STDIO input.
#[derive(Debug, Deserialize)]
struct JsonRpcRequest {
    // Protocol tag ("2.0"); accepted but never inspected.
    #[allow(dead_code)]
    jsonrpc: String,
    // Request id, echoed back in the response; `None` for notifications.
    id: Option<Value>,
    // Method name, e.g. "initialize", "tools/list", "tools/call".
    method: String,
    // Method-specific parameters, if any.
    params: Option<Value>,
}
|
||||
|
||||
/// JSON-RPC 2.0 Response
///
/// Exactly one of `result` / `error` is populated (see the `success` /
/// `error` constructors); the other is omitted from serialization.
#[derive(Debug, Serialize)]
struct JsonRpcResponse {
    jsonrpc: String,
    id: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    result: Option<Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<JsonRpcError>,
}
|
||||
|
||||
/// JSON-RPC 2.0 Error
///
/// Codes used in this module: -32601 (method not found), -32602 (invalid
/// params).
#[derive(Debug, Serialize)]
struct JsonRpcError {
    code: i32,
    message: String,
    // Optional structured error payload; currently always `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    data: Option<Value>,
}
|
||||
|
||||
/// MCP Server Info
///
/// Identity advertised to the client in the `initialize` response.
#[derive(Debug, Serialize)]
struct ServerInfo {
    name: String,
    version: String,
}
|
||||
|
||||
/// MCP Server Capabilities
///
/// Advertised during `initialize`; this server exposes tools only
/// (`capabilities()` always sets `resources` to `None`).
#[derive(Debug, Serialize)]
struct ServerCapabilities {
    tools: ToolsCapability,
    #[serde(skip_serializing_if = "Option::is_none")]
    resources: Option<ResourcesCapability>,
}
|
||||
|
||||
// Tools capability flags; `listChanged` tells clients whether the tool list
// can change at runtime (it cannot here).
#[derive(Debug, Serialize)]
struct ToolsCapability {
    #[serde(rename = "listChanged")]
    list_changed: bool,
}
|
||||
|
||||
// Resources capability flags. NOTE(review): declared but currently unused —
// `capabilities()` always reports `resources: None`.
#[derive(Debug, Serialize)]
struct ResourcesCapability {
    subscribe: bool,
    #[serde(rename = "listChanged")]
    list_changed: bool,
}
|
||||
|
||||
/// MCP Tool Definition
///
/// One entry of the `tools/list` payload; `input_schema` is a JSON Schema
/// object describing the tool's arguments.
#[derive(Debug, Serialize)]
struct Tool {
    name: String,
    description: String,
    #[serde(rename = "inputSchema")]
    input_schema: Value,
}
|
||||
|
||||
/// Tool call result
///
/// NOTE(review): declared but unused — `handle_tools_call` builds the
/// equivalent shape inline with `json!` instead.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct ToolResult {
    content: Vec<ContentBlock>,
    #[serde(rename = "isError", skip_serializing_if = "Option::is_none")]
    is_error: Option<bool>,
}
|
||||
|
||||
// A single MCP content block; only "type": "text" is produced by this server.
#[derive(Debug, Serialize)]
#[allow(dead_code)]
struct ContentBlock {
    #[serde(rename = "type")]
    content_type: String,
    text: String,
}
|
||||
|
||||
impl JsonRpcResponse {
|
||||
fn success(id: Value, result: Value) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: Some(result),
|
||||
error: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn error(id: Value, code: i32, message: &str) -> Self {
|
||||
Self {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
result: None,
|
||||
error: Some(JsonRpcError {
|
||||
code,
|
||||
message: message.to_string(),
|
||||
data: None,
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// MCP Server state
///
/// Lives for the duration of the STDIO session.
struct McpServer {
    // Echo protocol traffic to stderr when --debug was passed.
    debug: bool,
    // Optional override for the ONNX models directory; stored but not
    // consumed by any code visible in this module.
    #[allow(dead_code)]
    models_dir: Option<PathBuf>,
}
|
||||
|
||||
impl McpServer {
|
||||
fn new(args: &McpArgs) -> Self {
|
||||
Self {
|
||||
debug: args.debug,
|
||||
models_dir: args.models_dir.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get server info for initialization
|
||||
fn server_info(&self) -> ServerInfo {
|
||||
ServerInfo {
|
||||
name: "scipix-mcp".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get server capabilities
|
||||
fn capabilities(&self) -> ServerCapabilities {
|
||||
ServerCapabilities {
|
||||
tools: ToolsCapability {
|
||||
list_changed: false,
|
||||
},
|
||||
resources: None,
|
||||
}
|
||||
}
|
||||
|
||||
    /// Define available tools with examples following Anthropic best practices
    /// See: https://www.anthropic.com/engineering/advanced-tool-use
    ///
    /// Returned verbatim by `tools/list`. The WHEN TO USE / EXAMPLES /
    /// RETURNS sections inside each description are consumed by the AI host,
    /// so they are part of the server's public behavior — treat them as
    /// frozen strings, not comments.
    fn get_tools(&self) -> Vec<Tool> {
        vec![
            Tool {
                name: "ocr_image".to_string(),
                description: r#"Process an image file with OCR to extract text and mathematical formulas.

WHEN TO USE: Use this tool when you have an image file path containing text, equations,
or mathematical notation that needs to be converted to a machine-readable format.

EXAMPLES:
- Extract LaTeX from a photo of a math equation: {"image_path": "equation.png", "format": "latex"}
- Get plain text from a document scan: {"image_path": "document.jpg", "format": "text"}
- Convert handwritten math to AsciiMath: {"image_path": "notes.png", "format": "asciimath"}

RETURNS: JSON with the recognized content, confidence score (0-1), and processing metadata."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_path": {
                            "type": "string",
                            "description": "Absolute or relative path to image file (PNG, JPG, JPEG, GIF, BMP, TIFF supported)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "mathml", "asciimath"],
                            "default": "latex",
                            "description": "Output format: 'latex' for mathematical notation, 'text' for plain text, 'mathml' for XML, 'asciimath' for simple notation"
                        }
                    },
                    "required": ["image_path"],
                    "examples": [
                        {"image_path": "/path/to/equation.png", "format": "latex"},
                        {"image_path": "document.jpg", "format": "text"}
                    ]
                }),
            },
            Tool {
                name: "ocr_base64".to_string(),
                description: r#"Process a base64-encoded image with OCR. Use when image data is inline rather than a file.

WHEN TO USE: Use this tool when you have image data as a base64 string (e.g., from an API
response, clipboard, or embedded in a document) rather than a file path.

EXAMPLES:
- Process clipboard image: {"image_data": "iVBORw0KGgo...", "format": "latex"}
- Extract text from API response image: {"image_data": "<base64_string>", "format": "text"}

NOTE: The base64 string should not include the data URI prefix (e.g., "data:image/png;base64,")."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_data": {
                            "type": "string",
                            "description": "Base64-encoded image data (without data URI prefix)"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "mathml", "asciimath"],
                            "default": "latex",
                            "description": "Output format for recognized content"
                        }
                    },
                    "required": ["image_data"]
                }),
            },
            Tool {
                name: "batch_ocr".to_string(),
                description: r#"Process multiple images in a directory with OCR. Efficient for bulk operations.

WHEN TO USE: Use this tool when you need to process 3+ images in the same directory.
For 1-2 images, use ocr_image instead for simpler results.

EXAMPLES:
- Process all PNGs in a folder: {"directory": "./images", "pattern": "*.png"}
- Process specific equation images: {"directory": "/docs/math", "pattern": "eq_*.jpg"}
- Get JSON results for all images: {"directory": ".", "pattern": "*.{png,jpg}", "format": "json"}

RETURNS: Array of results with file paths, recognized content, and confidence scores."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "directory": {
                            "type": "string",
                            "description": "Directory path containing images to process"
                        },
                        "pattern": {
                            "type": "string",
                            "default": "*.png",
                            "description": "Glob pattern to match files (e.g., '*.png', '*.{jpg,png}', 'equation_*.jpg')"
                        },
                        "format": {
                            "type": "string",
                            "enum": ["latex", "text", "json"],
                            "default": "json",
                            "description": "Output format: 'json' for structured results (recommended), 'latex' or 'text' for concatenated output"
                        }
                    },
                    "required": ["directory"]
                }),
            },
            Tool {
                name: "preprocess_image".to_string(),
                description: r#"Apply preprocessing operations to optimize an image for OCR.

WHEN TO USE: Use this tool BEFORE ocr_image when dealing with:
- Low contrast images (use threshold)
- Large images that need resizing (use resize)
- Color images (use grayscale for faster processing)
- Noisy or blurry images (use denoise)

EXAMPLES:
- Prepare scan for OCR: {"image_path": "scan.jpg", "output_path": "scan_clean.png", "operations": ["grayscale", "threshold"]}
- Resize large image: {"image_path": "photo.jpg", "output_path": "photo_small.png", "operations": ["resize"], "target_width": 800}

WORKFLOW: preprocess_image -> ocr_image for best results on problematic images."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "image_path": {
                            "type": "string",
                            "description": "Path to input image file"
                        },
                        "output_path": {
                            "type": "string",
                            "description": "Path for preprocessed output image"
                        },
                        "operations": {
                            "type": "array",
                            "items": {
                                "type": "string",
                                "enum": ["grayscale", "resize", "threshold", "denoise", "deskew"]
                            },
                            "default": ["grayscale", "resize"],
                            "description": "Operations to apply in order: grayscale (convert to B&W), resize (scale to target size), threshold (binarize), denoise (reduce noise), deskew (straighten)"
                        },
                        "target_width": {
                            "type": "integer",
                            "default": 640,
                            "description": "Target width for resize (preserves aspect ratio)"
                        },
                        "target_height": {
                            "type": "integer",
                            "default": 480,
                            "description": "Target height for resize (preserves aspect ratio)"
                        }
                    },
                    "required": ["image_path", "output_path"]
                }),
            },
            Tool {
                name: "latex_to_mathml".to_string(),
                description: r#"Convert LaTeX mathematical notation to MathML XML format.

WHEN TO USE: Use this tool when you need MathML output from LaTeX, such as:
- Generating accessible math content for web pages
- Converting equations for screen readers
- Integrating with systems that require MathML

EXAMPLES:
- Convert fraction: {"latex": "\\frac{1}{2}"}
- Convert integral: {"latex": "\\int_0^1 x^2 dx"}
- Convert matrix: {"latex": "\\begin{pmatrix} a & b \\\\ c & d \\end{pmatrix}"}"#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "latex": {
                            "type": "string",
                            "description": "LaTeX expression to convert (with or without $ delimiters)"
                        }
                    },
                    "required": ["latex"],
                    "examples": [
                        {"latex": "\\frac{a}{b}"},
                        {"latex": "E = mc^2"}
                    ]
                }),
            },
            Tool {
                name: "benchmark_performance".to_string(),
                description: r#"Run performance benchmarks on the OCR pipeline and return timing metrics.

WHEN TO USE: Use this tool to:
- Verify OCR performance on your system
- Compare preprocessing options
- Debug slow processing issues

EXAMPLES:
- Quick performance check: {"iterations": 5}
- Test specific image: {"image_path": "test.png", "iterations": 10}

RETURNS: Average processing times for grayscale, resize operations, and system info."#.to_string(),
                input_schema: json!({
                    "type": "object",
                    "properties": {
                        "iterations": {
                            "type": "integer",
                            "default": 10,
                            "minimum": 1,
                            "maximum": 100,
                            "description": "Number of benchmark iterations (higher = more accurate, slower)"
                        },
                        "image_path": {
                            "type": "string",
                            "description": "Optional: Path to test image (uses generated test image if not provided)"
                        }
                    }
                }),
            },
        ]
    }
|
||||
|
||||
/// Handle incoming JSON-RPC request
|
||||
async fn handle_request(&self, request: JsonRpcRequest) -> JsonRpcResponse {
|
||||
let id = request.id.unwrap_or(Value::Null);
|
||||
|
||||
if self.debug {
|
||||
eprintln!("[MCP DEBUG] Method: {}", request.method);
|
||||
if let Some(ref params) = request.params {
|
||||
eprintln!(
|
||||
"[MCP DEBUG] Params: {}",
|
||||
serde_json::to_string_pretty(params).unwrap_or_default()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match request.method.as_str() {
|
||||
"initialize" => self.handle_initialize(id, request.params),
|
||||
"initialized" => JsonRpcResponse::success(id, json!({})),
|
||||
"tools/list" => self.handle_tools_list(id),
|
||||
"tools/call" => self.handle_tools_call(id, request.params).await,
|
||||
"ping" => JsonRpcResponse::success(id, json!({})),
|
||||
"shutdown" => {
|
||||
std::process::exit(0);
|
||||
}
|
||||
_ => {
|
||||
JsonRpcResponse::error(id, -32601, &format!("Method not found: {}", request.method))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle initialize request
|
||||
fn handle_initialize(&self, id: Value, params: Option<Value>) -> JsonRpcResponse {
|
||||
if self.debug {
|
||||
if let Some(p) = ¶ms {
|
||||
eprintln!(
|
||||
"[MCP DEBUG] Client info: {}",
|
||||
serde_json::to_string_pretty(p).unwrap_or_default()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
JsonRpcResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"serverInfo": self.server_info(),
|
||||
"capabilities": self.capabilities()
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Handle tools/list request
|
||||
fn handle_tools_list(&self, id: Value) -> JsonRpcResponse {
|
||||
JsonRpcResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"tools": self.get_tools()
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Handle tools/call request
|
||||
async fn handle_tools_call(&self, id: Value, params: Option<Value>) -> JsonRpcResponse {
|
||||
let params = match params {
|
||||
Some(p) => p,
|
||||
None => return JsonRpcResponse::error(id, -32602, "Missing params"),
|
||||
};
|
||||
|
||||
let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let arguments = params.get("arguments").cloned().unwrap_or(json!({}));
|
||||
|
||||
if self.debug {
|
||||
eprintln!(
|
||||
"[MCP DEBUG] Tool call: {} with args: {}",
|
||||
tool_name, arguments
|
||||
);
|
||||
}
|
||||
|
||||
let result = match tool_name {
|
||||
"ocr_image" => self.call_ocr_image(&arguments).await,
|
||||
"ocr_base64" => self.call_ocr_base64(&arguments).await,
|
||||
"batch_ocr" => self.call_batch_ocr(&arguments).await,
|
||||
"preprocess_image" => self.call_preprocess_image(&arguments).await,
|
||||
"latex_to_mathml" => self.call_latex_to_mathml(&arguments).await,
|
||||
"benchmark_performance" => self.call_benchmark(&arguments).await,
|
||||
_ => Err(format!("Unknown tool: {}", tool_name)),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(content) => JsonRpcResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": content
|
||||
}]
|
||||
}),
|
||||
),
|
||||
Err(e) => JsonRpcResponse::success(
|
||||
id,
|
||||
json!({
|
||||
"content": [{
|
||||
"type": "text",
|
||||
"text": e
|
||||
}],
|
||||
"isError": true
|
||||
}),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
/// OCR image file
|
||||
async fn call_ocr_image(&self, args: &Value) -> Result<String, String> {
|
||||
let image_path = args
|
||||
.get("image_path")
|
||||
.and_then(|p| p.as_str())
|
||||
.ok_or("Missing image_path parameter")?;
|
||||
|
||||
let format = args
|
||||
.get("format")
|
||||
.and_then(|f| f.as_str())
|
||||
.unwrap_or("latex");
|
||||
|
||||
// Check if file exists
|
||||
if !std::path::Path::new(image_path).exists() {
|
||||
return Err(format!("Image file not found: {}", image_path));
|
||||
}
|
||||
|
||||
// Load and process image
|
||||
let img = image::open(image_path).map_err(|e| format!("Failed to load image: {}", e))?;
|
||||
|
||||
// Perform OCR (using mock for now, real inference when models are available)
|
||||
let result = self.perform_ocr(&img, format).await?;
|
||||
|
||||
Ok(serde_json::to_string_pretty(&json!({
|
||||
"file": image_path,
|
||||
"format": format,
|
||||
"result": result,
|
||||
"confidence": 0.95
|
||||
}))
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// OCR base64 image
|
||||
async fn call_ocr_base64(&self, args: &Value) -> Result<String, String> {
|
||||
let image_data = args
|
||||
.get("image_data")
|
||||
.and_then(|d| d.as_str())
|
||||
.ok_or("Missing image_data parameter")?;
|
||||
|
||||
let format = args
|
||||
.get("format")
|
||||
.and_then(|f| f.as_str())
|
||||
.unwrap_or("latex");
|
||||
|
||||
// Decode base64
|
||||
let decoded =
|
||||
base64::Engine::decode(&base64::engine::general_purpose::STANDARD, image_data)
|
||||
.map_err(|e| format!("Invalid base64 data: {}", e))?;
|
||||
|
||||
// Load image from bytes
|
||||
let img = image::load_from_memory(&decoded)
|
||||
.map_err(|e| format!("Failed to load image from data: {}", e))?;
|
||||
|
||||
// Perform OCR
|
||||
let result = self.perform_ocr(&img, format).await?;
|
||||
|
||||
Ok(serde_json::to_string_pretty(&json!({
|
||||
"format": format,
|
||||
"result": result,
|
||||
"confidence": 0.95
|
||||
}))
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Batch OCR processing
///
/// Handle the `batch_ocr` tool call: OCR every file under `directory`
/// matching a glob `pattern` (default "*.png").
///
/// Per-file failures (unreadable image, OCR error) are recorded in the
/// results array instead of aborting the whole batch.
///
/// NOTE(review): `format` defaults to "json" here but is forwarded to
/// `perform_ocr`, which only recognizes latex/text/mathml/asciimath — the
/// default therefore produces "Unknown format" results; confirm intent.
async fn call_batch_ocr(&self, args: &Value) -> Result<String, String> {
    let directory = args
        .get("directory")
        .and_then(|d| d.as_str())
        .ok_or("Missing directory parameter")?;

    let pattern = args
        .get("pattern")
        .and_then(|p| p.as_str())
        .unwrap_or("*.png");

    let format = args
        .get("format")
        .and_then(|f| f.as_str())
        .unwrap_or("json");

    // Find files matching pattern; unreadable directory entries are
    // silently skipped by filter_map.
    let glob_pattern = format!("{}/{}", directory, pattern);
    let paths: Vec<_> = glob::glob(&glob_pattern)
        .map_err(|e| format!("Invalid glob pattern: {}", e))?
        .filter_map(|p| p.ok())
        .collect();

    let mut results = Vec::new();
    for path in &paths {
        let img = match image::open(path) {
            Ok(img) => img,
            Err(e) => {
                // Record the load failure for this file and keep going.
                results.push(json!({
                    "file": path.display().to_string(),
                    "error": e.to_string()
                }));
                continue;
            }
        };

        // On OCR failure the error string itself is stored as the result.
        let ocr_result = self.perform_ocr(&img, format).await.unwrap_or_else(|e| e);
        results.push(json!({
            "file": path.display().to_string(),
            "result": ocr_result,
            "confidence": 0.95
        }));
    }

    Ok(serde_json::to_string_pretty(&json!({
        "total": paths.len(),
        "processed": results.len(),
        "results": results
    }))
    .unwrap_or_default())
}
|
||||
|
||||
/// Preprocess image
///
/// Handle the `preprocess_image` tool call: load `image_path`, apply the
/// requested `operations` in order (default: grayscale then resize), and
/// save the result to `output_path`.
///
/// Supported operations: "grayscale", "resize" (uses `target_width` /
/// `target_height` args, defaulting to 640x480). Unknown operation names
/// are silently ignored.
async fn call_preprocess_image(&self, args: &Value) -> Result<String, String> {
    let image_path = args
        .get("image_path")
        .and_then(|p| p.as_str())
        .ok_or("Missing image_path parameter")?;

    let output_path = args
        .get("output_path")
        .and_then(|p| p.as_str())
        .ok_or("Missing output_path parameter")?;

    // Non-string array entries are dropped by filter_map.
    let operations: Vec<&str> = args
        .get("operations")
        .and_then(|o| o.as_array())
        .map(|arr| arr.iter().filter_map(|v| v.as_str()).collect())
        .unwrap_or_else(|| vec!["grayscale", "resize"]);

    // Load image
    let mut img =
        image::open(image_path).map_err(|e| format!("Failed to load image: {}", e))?;

    // Apply operations in the order given; order matters (e.g. resizing
    // before grayscale vs. after produces different intermediate work).
    for op in &operations {
        match *op {
            "grayscale" => {
                img = image::DynamicImage::ImageLuma8(img.to_luma8());
            }
            "resize" => {
                let width = args
                    .get("target_width")
                    .and_then(|w| w.as_u64())
                    .unwrap_or(640) as u32;
                let height = args
                    .get("target_height")
                    .and_then(|h| h.as_u64())
                    .unwrap_or(480) as u32;
                // `resize` preserves aspect ratio within the bounding box.
                img = img.resize(width, height, image::imageops::FilterType::Lanczos3);
            }
            _ => {}
        }
    }

    // Save output; format inferred from the output path extension.
    img.save(output_path)
        .map_err(|e| format!("Failed to save image: {}", e))?;

    Ok(serde_json::to_string_pretty(&json!({
        "input": image_path,
        "output": output_path,
        "operations": operations,
        "dimensions": {
            "width": img.width(),
            "height": img.height()
        }
    }))
    .unwrap_or_default())
}
|
||||
|
||||
/// Convert LaTeX to MathML
|
||||
async fn call_latex_to_mathml(&self, args: &Value) -> Result<String, String> {
|
||||
let latex = args
|
||||
.get("latex")
|
||||
.and_then(|l| l.as_str())
|
||||
.ok_or("Missing latex parameter")?;
|
||||
|
||||
// Simple LaTeX to MathML conversion (placeholder)
|
||||
let mathml = format!(
|
||||
r#"<math xmlns="http://www.w3.org/1998/Math/MathML"><mrow><mi>{}</mi></mrow></math>"#,
|
||||
latex.replace("\\", "").replace("{", "").replace("}", "")
|
||||
);
|
||||
|
||||
Ok(serde_json::to_string_pretty(&json!({
|
||||
"latex": latex,
|
||||
"mathml": mathml
|
||||
}))
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Run performance benchmark
|
||||
async fn call_benchmark(&self, args: &Value) -> Result<String, String> {
|
||||
let iterations = args
|
||||
.get("iterations")
|
||||
.and_then(|i| i.as_u64())
|
||||
.unwrap_or(10) as usize;
|
||||
|
||||
use std::time::Instant;
|
||||
|
||||
// Generate test image
|
||||
let test_img =
|
||||
image::DynamicImage::ImageRgb8(image::ImageBuffer::from_fn(400, 100, |_, _| {
|
||||
image::Rgb([255u8, 255u8, 255u8])
|
||||
}));
|
||||
|
||||
// Benchmark preprocessing
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
let _gray = test_img.to_luma8();
|
||||
}
|
||||
let grayscale_time = start.elapsed() / iterations as u32;
|
||||
|
||||
let start = Instant::now();
|
||||
for _ in 0..iterations {
|
||||
let _resized = test_img.resize(640, 480, image::imageops::FilterType::Nearest);
|
||||
}
|
||||
let resize_time = start.elapsed() / iterations as u32;
|
||||
|
||||
Ok(serde_json::to_string_pretty(&json!({
|
||||
"iterations": iterations,
|
||||
"benchmarks": {
|
||||
"grayscale_avg_ms": grayscale_time.as_secs_f64() * 1000.0,
|
||||
"resize_avg_ms": resize_time.as_secs_f64() * 1000.0,
|
||||
},
|
||||
"system": {
|
||||
"cpu_cores": num_cpus::get()
|
||||
}
|
||||
}))
|
||||
.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Perform OCR on image (placeholder implementation)
|
||||
async fn perform_ocr(
|
||||
&self,
|
||||
_img: &image::DynamicImage,
|
||||
format: &str,
|
||||
) -> Result<String, String> {
|
||||
// This is a placeholder - in production, this would call the actual OCR engine
|
||||
let result = match format {
|
||||
"latex" => r"\int_0^1 x^2 \, dx = \frac{1}{3}".to_string(),
|
||||
"text" => "Sample OCR extracted text".to_string(),
|
||||
"mathml" => r#"<math><mrow><mi>x</mi><mo>=</mo><mn>2</mn></mrow></math>"#.to_string(),
|
||||
"asciimath" => "int_0^1 x^2 dx = 1/3".to_string(),
|
||||
_ => "Unknown format".to_string(),
|
||||
};
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Run the MCP server
///
/// JSON-RPC-over-stdio event loop: reads one request per line from stdin,
/// dispatches it to `McpServer::handle_request`, and writes exactly one
/// response line to stdout. Diagnostics go to stderr (when `--debug`) so
/// they never corrupt the protocol stream. Malformed JSON produces a
/// -32700 "Parse error" response instead of terminating the loop; the
/// loop ends when stdin is closed.
pub async fn run(args: McpArgs) -> anyhow::Result<()> {
    let server = McpServer::new(&args);

    if args.debug {
        eprintln!("[MCP] SciPix MCP Server starting...");
        eprintln!("[MCP] Version: {}", env!("CARGO_PKG_VERSION"));
    }

    let stdin = io::stdin();
    let mut stdout = io::stdout();

    for line in stdin.lock().lines() {
        let line = match line {
            Ok(l) => l,
            Err(e) => {
                // Read errors on a single line are non-fatal; keep serving.
                if args.debug {
                    eprintln!("[MCP ERROR] Failed to read stdin: {}", e);
                }
                continue;
            }
        };

        // Skip blank keep-alive lines.
        if line.trim().is_empty() {
            continue;
        }

        if args.debug {
            eprintln!("[MCP DEBUG] Received: {}", line);
        }

        let request: JsonRpcRequest = match serde_json::from_str(&line) {
            Ok(req) => req,
            Err(e) => {
                // JSON-RPC parse error code is -32700 per the spec.
                let error_response =
                    JsonRpcResponse::error(Value::Null, -32700, &format!("Parse error: {}", e));
                let output = serde_json::to_string(&error_response).unwrap_or_default();
                writeln!(stdout, "{}", output)?;
                stdout.flush()?;
                continue;
            }
        };

        let response = server.handle_request(request).await;
        let output = serde_json::to_string(&response)?;

        if args.debug {
            eprintln!("[MCP DEBUG] Response: {}", output);
        }

        // Flush after every response: the peer blocks waiting for it.
        writeln!(stdout, "{}", output)?;
        stdout.flush()?;
    }

    Ok(())
}
|
||||
99
examples/scipix/src/cli/commands/mod.rs
Normal file
99
examples/scipix/src/cli/commands/mod.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
pub mod batch;
|
||||
pub mod config;
|
||||
pub mod doctor;
|
||||
pub mod mcp;
|
||||
pub mod ocr;
|
||||
pub mod serve;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Common result structure for OCR operations
///
/// Serialized directly as the CLI/API JSON output shape, so field names
/// are part of the public contract.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Source file path
    pub file: PathBuf,

    /// Extracted text content
    pub text: String,

    /// LaTeX representation (if available)
    pub latex: Option<String>,

    /// Confidence score (0.0 to 1.0)
    pub confidence: f64,

    /// Processing time in milliseconds
    pub processing_time_ms: u64,

    /// Any errors or warnings
    pub errors: Vec<String>,
}
|
||||
|
||||
impl OcrResult {
|
||||
/// Create a new OCR result
|
||||
pub fn new(file: PathBuf, text: String, confidence: f64) -> Self {
|
||||
Self {
|
||||
file,
|
||||
text,
|
||||
latex: None,
|
||||
confidence,
|
||||
processing_time_ms: 0,
|
||||
errors: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set LaTeX content
|
||||
pub fn with_latex(mut self, latex: String) -> Self {
|
||||
self.latex = Some(latex);
|
||||
self
|
||||
}
|
||||
|
||||
/// Set processing time
|
||||
pub fn with_processing_time(mut self, time_ms: u64) -> Self {
|
||||
self.processing_time_ms = time_ms;
|
||||
self
|
||||
}
|
||||
|
||||
/// Add an error message
|
||||
pub fn add_error(&mut self, error: String) {
|
||||
self.errors.push(error);
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for OCR processing
///
/// Deserialized from a TOML config file when `--config` is given (see
/// `load_config`); otherwise `Default::default()` supplies the values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrConfig {
    /// Minimum confidence threshold
    pub min_confidence: f64,

    /// Maximum image size in bytes
    pub max_image_size: usize,

    /// Supported file extensions
    // Compared lowercase against the input file's extension.
    pub supported_extensions: Vec<String>,

    /// API endpoint (if using remote service)
    pub api_endpoint: Option<String>,

    /// API key (if using remote service)
    pub api_key: Option<String>,
}
|
||||
|
||||
impl Default for OcrConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_confidence: 0.7,
|
||||
max_image_size: 10 * 1024 * 1024, // 10MB
|
||||
supported_extensions: vec![
|
||||
"png".to_string(),
|
||||
"jpg".to_string(),
|
||||
"jpeg".to_string(),
|
||||
"pdf".to_string(),
|
||||
"gif".to_string(),
|
||||
],
|
||||
api_endpoint: None,
|
||||
api_key: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
210
examples/scipix/src/cli/commands/ocr.rs
Normal file
210
examples/scipix/src/cli/commands/ocr.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
use anyhow::{Context, Result};
|
||||
use clap::Args;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Instant;
|
||||
use tracing::{debug, info};
|
||||
|
||||
use super::{OcrConfig, OcrResult};
|
||||
use crate::cli::{output, Cli, OutputFormat};
|
||||
|
||||
/// Process a single image or file with OCR
// NOTE: explicit `help = "…"` attributes take precedence over doc
// comments in clap-generated help, so added commentary here uses `//`
// to avoid changing the CLI's --help output.
#[derive(Args, Debug, Clone)]
pub struct OcrArgs {
    /// Path to the image file to process
    #[arg(value_name = "FILE", help = "Path to the image file")]
    pub file: PathBuf,

    /// Minimum confidence threshold (0.0 to 1.0)
    // Results below this value abort the command unless --force is set.
    #[arg(
        short = 't',
        long,
        default_value = "0.7",
        help = "Minimum confidence threshold for results"
    )]
    pub threshold: f64,

    /// Save output to file instead of stdout
    #[arg(
        short,
        long,
        value_name = "OUTPUT",
        help = "Save output to file instead of stdout"
    )]
    pub output: Option<PathBuf>,

    /// Pretty-print JSON output
    #[arg(
        short,
        long,
        help = "Pretty-print JSON output (only with --format json)"
    )]
    pub pretty: bool,

    /// Include metadata in output
    #[arg(short, long, help = "Include processing metadata in output")]
    pub metadata: bool,

    /// Force processing even if confidence is below threshold
    #[arg(
        short = 'f',
        long,
        help = "Force processing even if confidence is below threshold"
    )]
    pub force: bool,
}
|
||||
|
||||
/// Run the `ocr` subcommand: validate the input file, OCR it, and emit
/// the result in the requested output format.
///
/// Fails when the file is missing, not a regular file, has an
/// unsupported/missing extension, exceeds the configured size limit, or
/// when confidence is below `--threshold` and `--force` is not given.
pub async fn execute(args: OcrArgs, cli: &Cli) -> Result<()> {
    info!("Processing file: {}", args.file.display());

    // Validate input file
    if !args.file.exists() {
        anyhow::bail!("File not found: {}", args.file.display());
    }

    if !args.file.is_file() {
        anyhow::bail!("Not a file: {}", args.file.display());
    }

    // Load configuration (defaults when no --config was given)
    let config = load_config(cli.config.as_ref())?;

    // Validate file extension, compared case-insensitively
    if let Some(ext) = args.file.extension() {
        let ext_str = ext.to_string_lossy().to_lowercase();
        if !config.supported_extensions.contains(&ext_str) {
            anyhow::bail!(
                "Unsupported file extension: {}. Supported: {}",
                ext_str,
                config.supported_extensions.join(", ")
            );
        }
    } else {
        anyhow::bail!("File has no extension");
    }

    // Check file size before doing any expensive work
    let metadata = std::fs::metadata(&args.file).context("Failed to read file metadata")?;

    if metadata.len() as usize > config.max_image_size {
        anyhow::bail!(
            "File too large: {} bytes (max: {} bytes)",
            metadata.len(),
            config.max_image_size
        );
    }

    // Process the file
    let start = Instant::now();
    let result = process_file(&args.file, &config).await?;
    let processing_time = start.elapsed();

    debug!("Processing completed in {:?}", processing_time);

    // Check confidence threshold; --force bypasses the check
    if result.confidence < args.threshold && !args.force {
        anyhow::bail!(
            "Confidence {} is below threshold {} (use --force to override)",
            result.confidence,
            args.threshold
        );
    }

    // Format and output result
    let output_content = format_result(&result, &cli.format, args.pretty, args.metadata)?;

    if let Some(output_path) = &args.output {
        std::fs::write(output_path, &output_content).context("Failed to write output file")?;
        info!("Output saved to: {}", output_path.display());
    } else {
        println!("{}", output_content);
    }

    // Display summary if not quiet
    if !cli.quiet {
        output::print_ocr_summary(&result);
    }

    Ok(())
}
|
||||
|
||||
async fn process_file(file: &PathBuf, _config: &OcrConfig) -> Result<OcrResult> {
|
||||
// TODO: Implement actual OCR processing
|
||||
// For now, return a mock result
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
// Simulate OCR processing
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let processing_time = start.elapsed().as_millis() as u64;
|
||||
|
||||
Ok(OcrResult {
|
||||
file: file.clone(),
|
||||
text: "Sample OCR text from image".to_string(),
|
||||
latex: Some(r"\int_0^1 x^2 \, dx = \frac{1}{3}".to_string()),
|
||||
confidence: 0.95,
|
||||
processing_time_ms: processing_time,
|
||||
errors: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn format_result(
|
||||
result: &OcrResult,
|
||||
format: &OutputFormat,
|
||||
pretty: bool,
|
||||
include_metadata: bool,
|
||||
) -> Result<String> {
|
||||
match format {
|
||||
OutputFormat::Json => if include_metadata {
|
||||
if pretty {
|
||||
serde_json::to_string_pretty(result)
|
||||
} else {
|
||||
serde_json::to_string(result)
|
||||
}
|
||||
} else {
|
||||
let simple = serde_json::json!({
|
||||
"text": result.text,
|
||||
"latex": result.latex,
|
||||
"confidence": result.confidence,
|
||||
});
|
||||
if pretty {
|
||||
serde_json::to_string_pretty(&simple)
|
||||
} else {
|
||||
serde_json::to_string(&simple)
|
||||
}
|
||||
}
|
||||
.context("Failed to serialize to JSON"),
|
||||
OutputFormat::Text => Ok(result.text.clone()),
|
||||
OutputFormat::Latex => Ok(result.latex.clone().unwrap_or_else(|| result.text.clone())),
|
||||
OutputFormat::Markdown => {
|
||||
let mut md = format!("# OCR Result\n\n{}\n", result.text);
|
||||
if let Some(latex) = &result.latex {
|
||||
md.push_str(&format!("\n## LaTeX\n\n```latex\n{}\n```\n", latex));
|
||||
}
|
||||
if include_metadata {
|
||||
md.push_str(&format!(
|
||||
"\n---\n\nConfidence: {:.2}%\nProcessing time: {}ms\n",
|
||||
result.confidence * 100.0,
|
||||
result.processing_time_ms
|
||||
));
|
||||
}
|
||||
Ok(md)
|
||||
}
|
||||
OutputFormat::MathMl => {
|
||||
// TODO: Implement MathML conversion
|
||||
Ok(format!(
|
||||
"<math xmlns=\"http://www.w3.org/1998/Math/MathML\">\n {}\n</math>",
|
||||
result.text
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
|
||||
if let Some(path) = config_path {
|
||||
let content = std::fs::read_to_string(path).context("Failed to read config file")?;
|
||||
toml::from_str(&content).context("Failed to parse config file")
|
||||
} else {
|
||||
Ok(OcrConfig::default())
|
||||
}
|
||||
}
|
||||
293
examples/scipix/src/cli/commands/serve.rs
Normal file
293
examples/scipix/src/cli/commands/serve.rs
Normal file
@@ -0,0 +1,293 @@
|
||||
use anyhow::{Context, Result};
|
||||
use axum::{
|
||||
extract::{Multipart, State},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
Json, Router,
|
||||
};
|
||||
use clap::Args;
|
||||
use std::net::SocketAddr;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use tokio::signal;
|
||||
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||
use tracing::{info, warn};
|
||||
|
||||
use super::{OcrConfig, OcrResult};
|
||||
use crate::cli::Cli;
|
||||
|
||||
/// Start the API server
// NOTE: explicit `help = "…"` attributes take precedence in clap help,
// so added commentary uses `//` to leave --help output unchanged.
#[derive(Args, Debug, Clone)]
pub struct ServeArgs {
    /// Port to listen on
    #[arg(
        short,
        long,
        default_value = "8080",
        env = "MATHPIX_PORT",
        help = "Port to listen on"
    )]
    pub port: u16,

    /// Host to bind to
    // Default binds loopback only; set 0.0.0.0 to expose externally.
    #[arg(
        short = 'H',
        long,
        default_value = "127.0.0.1",
        env = "MATHPIX_HOST",
        help = "Host address to bind to"
    )]
    pub host: String,

    /// Directory containing ML models
    #[arg(
        long,
        value_name = "DIR",
        help = "Directory containing ML models to preload"
    )]
    pub model_dir: Option<PathBuf>,

    /// Enable CORS
    #[arg(long, help = "Enable CORS for cross-origin requests")]
    pub cors: bool,

    /// Maximum request size in MB
    #[arg(long, default_value = "10", help = "Maximum request size in megabytes")]
    pub max_size: usize,

    /// Number of worker threads
    // NOTE(review): not consumed by `execute` in this file — confirm
    // whether it is wired into the runtime elsewhere.
    #[arg(
        short = 'w',
        long,
        default_value = "4",
        help = "Number of worker threads"
    )]
    pub workers: usize,
}
|
||||
|
||||
/// Shared, cheaply clonable state handed to every request handler.
#[derive(Clone)]
struct AppState {
    /// Read-only OCR configuration, shared across handlers via Arc.
    config: Arc<OcrConfig>,
    /// Maximum accepted upload size in bytes (from --max-size, MB → bytes).
    max_size: usize,
}
|
||||
|
||||
/// Run the `serve` subcommand: build the axum router, bind the listener,
/// and serve until Ctrl+C / SIGTERM triggers graceful shutdown.
///
/// Routes: `/` (banner), `/health`, `POST /api/v1/ocr`,
/// `POST /api/v1/batch`. CORS is opt-in via `--cors`.
///
/// NOTE(review): `args.workers` is not used in this function — confirm
/// whether it should configure the runtime.
pub async fn execute(args: ServeArgs, cli: &Cli) -> Result<()> {
    info!("Starting Scipix API server");

    // Load configuration
    let config = Arc::new(load_config(cli.config.as_ref())?);

    // Preload models if specified; fails fast on a bad directory.
    if let Some(model_dir) = &args.model_dir {
        info!("Preloading models from: {}", model_dir.display());
        preload_models(model_dir)?;
    }

    // Create app state; max_size is converted from megabytes to bytes.
    let state = AppState {
        config,
        max_size: args.max_size * 1024 * 1024,
    };

    // Build router
    let mut app = Router::new()
        .route("/", get(root))
        .route("/health", get(health))
        .route("/api/v1/ocr", post(ocr_handler))
        .route("/api/v1/batch", post(batch_handler))
        .with_state(state)
        .layer(TraceLayer::new_for_http());

    // Add CORS if enabled
    if args.cors {
        app = app.layer(CorsLayer::permissive());
        info!("CORS enabled");
    }

    // Create socket address
    let addr: SocketAddr = format!("{}:{}", args.host, args.port)
        .parse()
        .context("Invalid host/port combination")?;

    info!("Server listening on http://{}", addr);
    info!("API endpoints:");
    info!("  POST http://{}/api/v1/ocr - Single file OCR", addr);
    info!("  POST http://{}/api/v1/batch - Batch processing", addr);
    info!("  GET  http://{}/health - Health check", addr);

    // Create server
    let listener = tokio::net::TcpListener::bind(addr)
        .await
        .context("Failed to bind to address")?;

    // Run server with graceful shutdown
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await
        .context("Server error")?;

    info!("Server shutdown complete");
    Ok(())
}
|
||||
|
||||
/// Plain-text landing page listing the available endpoints.
async fn root() -> &'static str {
    concat!(
        "Scipix OCR API Server\n\n",
        "Endpoints:\n",
        "  POST /api/v1/ocr - Single file OCR\n",
        "  POST /api/v1/batch - Batch processing\n",
        "  GET  /health - Health check"
    )
}
|
||||
|
||||
async fn health() -> impl IntoResponse {
|
||||
Json(serde_json::json!({
|
||||
"status": "healthy",
|
||||
"version": env!("CARGO_PKG_VERSION"),
|
||||
}))
|
||||
}
|
||||
|
||||
/// Handle `POST /api/v1/ocr`: OCR a single multipart-uploaded file.
///
/// Scans the multipart fields for one named "file"; other fields are
/// ignored. Rejects uploads larger than `state.max_size` with 413, bad
/// multipart data with 400, and processing failures with 500. Returns
/// the `OcrResult` for the first matching field.
async fn ocr_handler(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> Result<Json<OcrResult>, (StatusCode, String)> {
    while let Some(field) = multipart
        .next_field()
        .await
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
    {
        let name = field.name().unwrap_or("").to_string();

        if name == "file" {
            // Buffer the whole field into memory before size-checking.
            let data = field
                .bytes()
                .await
                .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;

            if data.len() > state.max_size {
                return Err((
                    StatusCode::PAYLOAD_TOO_LARGE,
                    format!(
                        "File too large: {} bytes (max: {} bytes)",
                        data.len(),
                        state.max_size
                    ),
                ));
            }

            // Process the file
            let result = process_image_data(&data, &state.config)
                .await
                .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?;

            return Ok(Json(result));
        }
    }

    Err((StatusCode::BAD_REQUEST, "No file provided".to_string()))
}
|
||||
|
||||
/// Handle `POST /api/v1/batch`: OCR every multipart field named "files".
///
/// Best-effort semantics: oversized or failing files are skipped with a
/// warning instead of aborting the batch. Returns 400 only when no file
/// was processed successfully.
async fn batch_handler(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> Result<Json<Vec<OcrResult>>, (StatusCode, String)> {
    let mut results = Vec::new();

    while let Some(field) = multipart
        .next_field()
        .await
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?
    {
        let name = field.name().unwrap_or("").to_string();

        if name == "files" {
            let data = field
                .bytes()
                .await
                .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;

            // Skip (don't fail) files over the size limit.
            if data.len() > state.max_size {
                warn!("Skipping file: too large ({} bytes)", data.len());
                continue;
            }

            // Process the file; per-file errors are logged and skipped.
            match process_image_data(&data, &state.config).await {
                Ok(result) => results.push(result),
                Err(e) => warn!("Failed to process file: {}", e),
            }
        }
    }

    if results.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            "No valid files processed".to_string(),
        ));
    }

    Ok(Json(results))
}
|
||||
|
||||
async fn process_image_data(data: &[u8], _config: &OcrConfig) -> Result<OcrResult> {
|
||||
// TODO: Implement actual OCR processing
|
||||
// For now, return a mock result
|
||||
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
|
||||
|
||||
Ok(OcrResult {
|
||||
file: PathBuf::from("uploaded_file"),
|
||||
text: format!("OCR text from uploaded image ({} bytes)", data.len()),
|
||||
latex: Some(r"\text{Sample LaTeX}".to_string()),
|
||||
confidence: 0.92,
|
||||
processing_time_ms: 50,
|
||||
errors: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
fn preload_models(model_dir: &PathBuf) -> Result<()> {
|
||||
if !model_dir.exists() {
|
||||
anyhow::bail!("Model directory not found: {}", model_dir.display());
|
||||
}
|
||||
|
||||
if !model_dir.is_dir() {
|
||||
anyhow::bail!("Not a directory: {}", model_dir.display());
|
||||
}
|
||||
|
||||
// TODO: Implement model preloading
|
||||
info!("Models preloaded from {}", model_dir.display());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_config(config_path: Option<&PathBuf>) -> Result<OcrConfig> {
|
||||
if let Some(path) = config_path {
|
||||
let content = std::fs::read_to_string(path).context("Failed to read config file")?;
|
||||
toml::from_str(&content).context("Failed to parse config file")
|
||||
} else {
|
||||
Ok(OcrConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve when a shutdown signal arrives.
///
/// Completes on Ctrl+C on all platforms and additionally on SIGTERM on
/// Unix; on non-Unix targets the terminate branch is a pending future
/// that never fires. Used as axum's graceful-shutdown trigger.
async fn shutdown_signal() {
    let ctrl_c = async {
        signal::ctrl_c()
            .await
            .expect("failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let terminate = async {
        signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("failed to install signal handler")
            .recv()
            .await;
    };

    #[cfg(not(unix))]
    let terminate = std::future::pending::<()>();

    // Whichever signal arrives first wins.
    tokio::select! {
        _ = ctrl_c => {
            info!("Received Ctrl+C signal");
        },
        _ = terminate => {
            info!("Received terminate signal");
        },
    }
}
|
||||
115
examples/scipix/src/cli/mod.rs
Normal file
115
examples/scipix/src/cli/mod.rs
Normal file
@@ -0,0 +1,115 @@
|
||||
pub mod commands;
|
||||
pub mod output;
|
||||
|
||||
use clap::{Parser, Subcommand};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Scipix CLI - OCR and mathematical content processing
// NOTE: clap renders doc comments and `help = "…"` attributes into
// --help output; added commentary below uses `//` so the generated help
// text stays byte-identical.
#[derive(Parser, Debug)]
#[command(
    name = "scipix-cli",
    version,
    about = "A Rust-based CLI for Scipix OCR processing",
    long_about = "Process images with OCR, extract mathematical formulas, and convert to LaTeX or other formats.\n\n\
                  Supports single file processing, batch operations, and API server mode."
)]
pub struct Cli {
    /// Path to configuration file
    // Global: available to every subcommand; also read from MATHPIX_CONFIG.
    #[arg(
        short,
        long,
        global = true,
        env = "MATHPIX_CONFIG",
        help = "Path to configuration file"
    )]
    pub config: Option<PathBuf>,

    /// Enable verbose logging
    #[arg(
        short,
        long,
        global = true,
        help = "Enable verbose logging (DEBUG level)"
    )]
    pub verbose: bool,

    /// Suppress all non-error output
    // Mutually exclusive with --verbose (clap enforces the conflict).
    #[arg(
        short,
        long,
        global = true,
        conflicts_with = "verbose",
        help = "Suppress all non-error output"
    )]
    pub quiet: bool,

    /// Output format (json, text, latex, markdown)
    #[arg(
        short,
        long,
        global = true,
        default_value = "text",
        help = "Output format for results"
    )]
    pub format: OutputFormat,

    // Selected subcommand; see `Commands`.
    #[command(subcommand)]
    pub command: Commands,
}
|
||||
|
||||
// Top-level subcommands. Variant doc comments double as clap help text,
// so they are kept user-facing and unchanged.
#[derive(Subcommand, Debug)]
pub enum Commands {
    /// Process a single image or file with OCR
    Ocr(commands::ocr::OcrArgs),

    /// Process multiple files in batch mode
    Batch(commands::batch::BatchArgs),

    /// Start the API server
    Serve(commands::serve::ServeArgs),

    /// Start the MCP (Model Context Protocol) server for AI integration
    Mcp(commands::mcp::McpArgs),

    /// Manage configuration
    Config(commands::config::ConfigArgs),

    /// Diagnose environment and optimize configuration
    Doctor(commands::doctor::DoctorArgs),

    /// Show version information
    Version,

    /// Generate shell completions
    Completions {
        /// Shell to generate completions for (bash, zsh, fish, powershell)
        #[arg(value_enum)]
        shell: Option<clap_complete::Shell>,
    },
}
|
||||
|
||||
// Output formats selectable via --format; lowercase names come from
// clap's ValueEnum derive, matching the Display impl below.
#[derive(Debug, Clone, Copy, clap::ValueEnum)]
pub enum OutputFormat {
    /// Plain text output
    Text,
    /// JSON output
    Json,
    /// LaTeX format
    Latex,
    /// Markdown format
    Markdown,
    /// MathML format
    MathMl,
}
|
||||
|
||||
impl std::fmt::Display for OutputFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
OutputFormat::Text => write!(f, "text"),
|
||||
OutputFormat::Json => write!(f, "json"),
|
||||
OutputFormat::Latex => write!(f, "latex"),
|
||||
OutputFormat::Markdown => write!(f, "markdown"),
|
||||
OutputFormat::MathMl => write!(f, "mathml"),
|
||||
}
|
||||
}
|
||||
}
|
||||
223
examples/scipix/src/cli/output.rs
Normal file
223
examples/scipix/src/cli/output.rs
Normal file
@@ -0,0 +1,223 @@
|
||||
use comfy_table::{modifiers::UTF8_ROUND_CORNERS, presets::UTF8_FULL, Cell, Color, Table};
|
||||
use console::style;
|
||||
|
||||
use super::commands::OcrResult;
|
||||
|
||||
/// Print a summary of a single OCR result
|
||||
pub fn print_ocr_summary(result: &OcrResult) {
|
||||
println!("\n{}", style("OCR Processing Summary").bold().cyan());
|
||||
println!("{}", style("─".repeat(60)).dim());
|
||||
|
||||
let mut table = Table::new();
|
||||
table
|
||||
.load_preset(UTF8_FULL)
|
||||
.apply_modifier(UTF8_ROUND_CORNERS)
|
||||
.set_header(vec![
|
||||
Cell::new("Property").fg(Color::Cyan),
|
||||
Cell::new("Value").fg(Color::Green),
|
||||
]);
|
||||
|
||||
table.add_row(vec![
|
||||
Cell::new("File"),
|
||||
Cell::new(result.file.display().to_string()),
|
||||
]);
|
||||
|
||||
table.add_row(vec![
|
||||
Cell::new("Confidence"),
|
||||
Cell::new(format!("{:.2}%", result.confidence * 100.0))
|
||||
.fg(confidence_color(result.confidence)),
|
||||
]);
|
||||
|
||||
table.add_row(vec![
|
||||
Cell::new("Processing Time"),
|
||||
Cell::new(format!("{}ms", result.processing_time_ms)),
|
||||
]);
|
||||
|
||||
if let Some(latex) = &result.latex {
|
||||
table.add_row(vec![
|
||||
Cell::new("LaTeX"),
|
||||
Cell::new(if latex.len() > 50 {
|
||||
format!("{}...", &latex[..50])
|
||||
} else {
|
||||
latex.clone()
|
||||
}),
|
||||
]);
|
||||
}
|
||||
|
||||
if !result.errors.is_empty() {
|
||||
table.add_row(vec![
|
||||
Cell::new("Errors").fg(Color::Red),
|
||||
Cell::new(result.errors.len().to_string()).fg(Color::Red),
|
||||
]);
|
||||
}
|
||||
|
||||
println!("{table}");
|
||||
|
||||
if !result.errors.is_empty() {
|
||||
println!("\n{}", style("Errors:").bold().red());
|
||||
for (i, error) in result.errors.iter().enumerate() {
|
||||
println!(" {}. {}", i + 1, style(error).red());
|
||||
}
|
||||
}
|
||||
|
||||
println!();
|
||||
}
|
||||
|
||||
/// Print a summary of batch processing results
///
/// `passed`/`failed` are the results above/below the confidence
/// `threshold`. Prints an aggregate metrics table, a per-file table of
/// failures, and min/max statistics over the passing set.
///
/// NOTE(review): when both slices are empty, `total` is 0 and the
/// percentage rows divide by zero — f64 division yields NaN (printed as
/// "NaN%"), not a panic; confirm whether an early return is preferred.
pub fn print_batch_summary(passed: &[OcrResult], failed: &[OcrResult], threshold: f64) {
    println!("\n{}", style("Batch Processing Summary").bold().cyan());
    println!("{}", style("═".repeat(60)).dim());

    let total = passed.len() + failed.len();
    // Averages are computed over the passing set only.
    let avg_confidence = if !passed.is_empty() {
        passed.iter().map(|r| r.confidence).sum::<f64>() / passed.len() as f64
    } else {
        0.0
    };
    let total_time: u64 = passed.iter().map(|r| r.processing_time_ms).sum();
    let avg_time = if !passed.is_empty() {
        total_time / passed.len() as u64
    } else {
        0
    };

    let mut table = Table::new();
    table
        .load_preset(UTF8_FULL)
        .apply_modifier(UTF8_ROUND_CORNERS)
        .set_header(vec![
            Cell::new("Metric").fg(Color::Cyan),
            Cell::new("Value").fg(Color::Green),
        ]);

    table.add_row(vec![Cell::new("Total Files"), Cell::new(total.to_string())]);

    table.add_row(vec![
        Cell::new("Passed").fg(Color::Green),
        Cell::new(format!(
            "{} ({:.1}%)",
            passed.len(),
            (passed.len() as f64 / total as f64) * 100.0
        ))
        .fg(Color::Green),
    ]);

    table.add_row(vec![
        Cell::new("Failed").fg(Color::Red),
        Cell::new(format!(
            "{} ({:.1}%)",
            failed.len(),
            (failed.len() as f64 / total as f64) * 100.0
        ))
        // Green when there are no failures, red otherwise.
        .fg(if failed.is_empty() {
            Color::Green
        } else {
            Color::Red
        }),
    ]);

    table.add_row(vec![
        Cell::new("Threshold"),
        Cell::new(format!("{:.2}%", threshold * 100.0)),
    ]);

    table.add_row(vec![
        Cell::new("Avg Confidence"),
        Cell::new(format!("{:.2}%", avg_confidence * 100.0)).fg(confidence_color(avg_confidence)),
    ]);

    table.add_row(vec![
        Cell::new("Avg Processing Time"),
        Cell::new(format!("{}ms", avg_time)),
    ]);

    table.add_row(vec![
        Cell::new("Total Processing Time"),
        Cell::new(format!("{:.2}s", total_time as f64 / 1000.0)),
    ]);

    println!("{table}");

    if !failed.is_empty() {
        println!("\n{}", style("Failed Files:").bold().red());

        let mut failed_table = Table::new();
        failed_table
            .load_preset(UTF8_FULL)
            .apply_modifier(UTF8_ROUND_CORNERS)
            .set_header(vec![
                Cell::new("#").fg(Color::Cyan),
                Cell::new("File").fg(Color::Cyan),
                Cell::new("Confidence").fg(Color::Cyan),
            ]);

        for (i, result) in failed.iter().enumerate() {
            failed_table.add_row(vec![
                Cell::new((i + 1).to_string()),
                Cell::new(result.file.display().to_string()),
                Cell::new(format!("{:.2}%", result.confidence * 100.0)).fg(Color::Red),
            ]);
        }

        println!("{failed_table}");
    }

    // Summary statistics
    println!("\n{}", style("Statistics:").bold().cyan());

    if !passed.is_empty() {
        let confidences: Vec<f64> = passed.iter().map(|r| r.confidence).collect();
        // fold with INFINITY/NEG_INFINITY handles NaN-free min/max over f64.
        let min_confidence = confidences.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_confidence = confidences
            .iter()
            .cloned()
            .fold(f64::NEG_INFINITY, f64::max);

        println!(
            "  Min confidence: {}",
            style(format!("{:.2}%", min_confidence * 100.0)).green()
        );
        println!(
            "  Max confidence: {}",
            style(format!("{:.2}%", max_confidence * 100.0)).green()
        );

        let times: Vec<u64> = passed.iter().map(|r| r.processing_time_ms).collect();
        let min_time = times.iter().min().unwrap_or(&0);
        let max_time = times.iter().max().unwrap_or(&0);

        println!("  Min processing time: {}ms", style(min_time).cyan());
        println!("  Max processing time: {}ms", style(max_time).cyan());
    }

    println!();
}
|
||||
|
||||
/// Get color based on confidence value
|
||||
fn confidence_color(confidence: f64) -> Color {
|
||||
if confidence >= 0.9 {
|
||||
Color::Green
|
||||
} else if confidence >= 0.7 {
|
||||
Color::Yellow
|
||||
} else {
|
||||
Color::Red
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a progress bar style for batch processing
|
||||
pub fn create_progress_style() -> indicatif::ProgressStyle {
|
||||
indicatif::ProgressStyle::default_bar()
|
||||
.template(
|
||||
"{spinner:.green} [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta}) {msg}",
|
||||
)
|
||||
.unwrap()
|
||||
.progress_chars("█▓▒░ ")
|
||||
}
|
||||
|
||||
/// Create a spinner style for individual file processing
|
||||
pub fn create_spinner_style() -> indicatif::ProgressStyle {
|
||||
indicatif::ProgressStyle::default_spinner()
|
||||
.template("{spinner:.cyan} {msg}")
|
||||
.unwrap()
|
||||
.tick_chars("⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏")
|
||||
}
|
||||
455
examples/scipix/src/config.rs
Normal file
455
examples/scipix/src/config.rs
Normal file
@@ -0,0 +1,455 @@
|
||||
//! Configuration system for Ruvector-Scipix
|
||||
//!
|
||||
//! Comprehensive configuration with TOML support, environment overrides, and validation.
|
||||
|
||||
use crate::error::{Result, ScipixError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
/// Main configuration structure
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Config {
|
||||
/// OCR processing configuration
|
||||
pub ocr: OcrConfig,
|
||||
|
||||
/// Model configuration
|
||||
pub model: ModelConfig,
|
||||
|
||||
/// Preprocessing configuration
|
||||
pub preprocess: PreprocessConfig,
|
||||
|
||||
/// Output format configuration
|
||||
pub output: OutputConfig,
|
||||
|
||||
/// Performance tuning
|
||||
pub performance: PerformanceConfig,
|
||||
|
||||
/// Cache configuration
|
||||
pub cache: CacheConfig,
|
||||
}
|
||||
|
||||
/// OCR engine configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrConfig {
|
||||
/// Confidence threshold (0.0-1.0)
|
||||
pub confidence_threshold: f32,
|
||||
|
||||
/// Maximum processing time in seconds
|
||||
pub timeout: u64,
|
||||
|
||||
/// Enable GPU acceleration
|
||||
pub use_gpu: bool,
|
||||
|
||||
/// Language codes (e.g., ["en", "es"])
|
||||
pub languages: Vec<String>,
|
||||
|
||||
/// Enable equation detection
|
||||
pub detect_equations: bool,
|
||||
|
||||
/// Enable table detection
|
||||
pub detect_tables: bool,
|
||||
}
|
||||
|
||||
/// Model configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelConfig {
|
||||
/// Path to OCR model
|
||||
pub model_path: String,
|
||||
|
||||
/// Model version
|
||||
pub version: String,
|
||||
|
||||
/// Batch size for processing
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Model precision (fp16, fp32, int8)
|
||||
pub precision: String,
|
||||
|
||||
/// Enable quantization
|
||||
pub quantize: bool,
|
||||
}
|
||||
|
||||
/// Image preprocessing configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PreprocessConfig {
|
||||
/// Enable auto-rotation
|
||||
pub auto_rotate: bool,
|
||||
|
||||
/// Enable denoising
|
||||
pub denoise: bool,
|
||||
|
||||
/// Enable contrast enhancement
|
||||
pub enhance_contrast: bool,
|
||||
|
||||
/// Enable binarization
|
||||
pub binarize: bool,
|
||||
|
||||
/// Target DPI for scaling
|
||||
pub target_dpi: u32,
|
||||
|
||||
/// Maximum image dimension
|
||||
pub max_dimension: u32,
|
||||
}
|
||||
|
||||
/// Output format configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OutputConfig {
|
||||
/// Output formats (latex, mathml, asciimath)
|
||||
pub formats: Vec<String>,
|
||||
|
||||
/// Include confidence scores
|
||||
pub include_confidence: bool,
|
||||
|
||||
/// Include bounding boxes
|
||||
pub include_bbox: bool,
|
||||
|
||||
/// Pretty print JSON
|
||||
pub pretty_print: bool,
|
||||
|
||||
/// Include metadata
|
||||
pub include_metadata: bool,
|
||||
}
|
||||
|
||||
/// Performance tuning configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PerformanceConfig {
|
||||
/// Number of worker threads
|
||||
pub num_threads: usize,
|
||||
|
||||
/// Enable parallel processing
|
||||
pub parallel: bool,
|
||||
|
||||
/// Memory limit in MB
|
||||
pub memory_limit: usize,
|
||||
|
||||
/// Enable profiling
|
||||
pub profile: bool,
|
||||
}
|
||||
|
||||
/// Cache configuration
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CacheConfig {
|
||||
/// Enable caching
|
||||
pub enabled: bool,
|
||||
|
||||
/// Cache capacity (number of entries)
|
||||
pub capacity: usize,
|
||||
|
||||
/// Similarity threshold for cache hits (0.0-1.0)
|
||||
pub similarity_threshold: f32,
|
||||
|
||||
/// Cache TTL in seconds
|
||||
pub ttl: u64,
|
||||
|
||||
/// Vector dimension for embeddings
|
||||
pub vector_dimension: usize,
|
||||
|
||||
/// Enable persistent cache
|
||||
pub persistent: bool,
|
||||
|
||||
/// Cache directory path
|
||||
pub cache_dir: String,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
ocr: OcrConfig {
|
||||
confidence_threshold: 0.7,
|
||||
timeout: 30,
|
||||
use_gpu: false,
|
||||
languages: vec!["en".to_string()],
|
||||
detect_equations: true,
|
||||
detect_tables: true,
|
||||
},
|
||||
model: ModelConfig {
|
||||
model_path: "models/scipix-ocr".to_string(),
|
||||
version: "1.0.0".to_string(),
|
||||
batch_size: 1,
|
||||
precision: "fp32".to_string(),
|
||||
quantize: false,
|
||||
},
|
||||
preprocess: PreprocessConfig {
|
||||
auto_rotate: true,
|
||||
denoise: true,
|
||||
enhance_contrast: true,
|
||||
binarize: false,
|
||||
target_dpi: 300,
|
||||
max_dimension: 4096,
|
||||
},
|
||||
output: OutputConfig {
|
||||
formats: vec!["latex".to_string()],
|
||||
include_confidence: true,
|
||||
include_bbox: false,
|
||||
pretty_print: true,
|
||||
include_metadata: false,
|
||||
},
|
||||
performance: PerformanceConfig {
|
||||
num_threads: num_cpus::get(),
|
||||
parallel: true,
|
||||
memory_limit: 2048,
|
||||
profile: false,
|
||||
},
|
||||
cache: CacheConfig {
|
||||
enabled: true,
|
||||
capacity: 1000,
|
||||
similarity_threshold: 0.95,
|
||||
ttl: 3600,
|
||||
vector_dimension: 512,
|
||||
persistent: false,
|
||||
cache_dir: ".cache/scipix".to_string(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Load configuration from TOML file
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to TOML configuration file
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use ruvector_scipix::Config;
|
||||
///
|
||||
/// let config = Config::from_file("scipix.toml").unwrap();
|
||||
/// ```
|
||||
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let config: Config = toml::from_str(&content)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Save configuration to TOML file
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to save TOML configuration
|
||||
pub fn to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
|
||||
let content = toml::to_string_pretty(self)?;
|
||||
std::fs::write(path, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load configuration from environment variables
|
||||
///
|
||||
/// Environment variables should be prefixed with `MATHPIX_`
|
||||
/// and use double underscores for nested fields.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```bash
|
||||
/// export MATHPIX_OCR__CONFIDENCE_THRESHOLD=0.8
|
||||
/// export MATHPIX_MODEL__BATCH_SIZE=4
|
||||
/// ```
|
||||
pub fn from_env() -> Result<Self> {
|
||||
let mut config = Self::default();
|
||||
config.apply_env_overrides()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Apply environment variable overrides
|
||||
fn apply_env_overrides(&mut self) -> Result<()> {
|
||||
// OCR overrides
|
||||
if let Ok(val) = std::env::var("MATHPIX_OCR__CONFIDENCE_THRESHOLD") {
|
||||
self.ocr.confidence_threshold = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid confidence_threshold".to_string()))?;
|
||||
}
|
||||
if let Ok(val) = std::env::var("MATHPIX_OCR__TIMEOUT") {
|
||||
self.ocr.timeout = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid timeout".to_string()))?;
|
||||
}
|
||||
if let Ok(val) = std::env::var("MATHPIX_OCR__USE_GPU") {
|
||||
self.ocr.use_gpu = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid use_gpu".to_string()))?;
|
||||
}
|
||||
|
||||
// Model overrides
|
||||
if let Ok(val) = std::env::var("MATHPIX_MODEL__PATH") {
|
||||
self.model.model_path = val;
|
||||
}
|
||||
if let Ok(val) = std::env::var("MATHPIX_MODEL__BATCH_SIZE") {
|
||||
self.model.batch_size = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid batch_size".to_string()))?;
|
||||
}
|
||||
|
||||
// Cache overrides
|
||||
if let Ok(val) = std::env::var("MATHPIX_CACHE__ENABLED") {
|
||||
self.cache.enabled = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid cache enabled".to_string()))?;
|
||||
}
|
||||
if let Ok(val) = std::env::var("MATHPIX_CACHE__CAPACITY") {
|
||||
self.cache.capacity = val
|
||||
.parse()
|
||||
.map_err(|_| ScipixError::Config("Invalid cache capacity".to_string()))?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate configuration
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
// Validate confidence threshold
|
||||
if self.ocr.confidence_threshold < 0.0 || self.ocr.confidence_threshold > 1.0 {
|
||||
return Err(ScipixError::Config(
|
||||
"confidence_threshold must be between 0.0 and 1.0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate similarity threshold
|
||||
if self.cache.similarity_threshold < 0.0 || self.cache.similarity_threshold > 1.0 {
|
||||
return Err(ScipixError::Config(
|
||||
"similarity_threshold must be between 0.0 and 1.0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate batch size
|
||||
if self.model.batch_size == 0 {
|
||||
return Err(ScipixError::Config(
|
||||
"batch_size must be greater than 0".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Validate precision
|
||||
let valid_precisions = ["fp16", "fp32", "int8"];
|
||||
if !valid_precisions.contains(&self.model.precision.as_str()) {
|
||||
return Err(ScipixError::Config(format!(
|
||||
"precision must be one of: {:?}",
|
||||
valid_precisions
|
||||
)));
|
||||
}
|
||||
|
||||
// Validate output formats
|
||||
let valid_formats = ["latex", "mathml", "asciimath"];
|
||||
for format in &self.output.formats {
|
||||
if !valid_formats.contains(&format.as_str()) {
|
||||
return Err(ScipixError::Config(format!(
|
||||
"Invalid output format: {}. Must be one of: {:?}",
|
||||
format, valid_formats
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create high-accuracy preset configuration
|
||||
pub fn high_accuracy() -> Self {
|
||||
let mut config = Self::default();
|
||||
config.ocr.confidence_threshold = 0.9;
|
||||
config.model.precision = "fp32".to_string();
|
||||
config.model.quantize = false;
|
||||
config.preprocess.denoise = true;
|
||||
config.preprocess.enhance_contrast = true;
|
||||
config.cache.similarity_threshold = 0.98;
|
||||
config
|
||||
}
|
||||
|
||||
/// Create high-speed preset configuration
|
||||
pub fn high_speed() -> Self {
|
||||
let mut config = Self::default();
|
||||
config.ocr.confidence_threshold = 0.6;
|
||||
config.model.precision = "fp16".to_string();
|
||||
config.model.quantize = true;
|
||||
config.model.batch_size = 4;
|
||||
config.preprocess.denoise = false;
|
||||
config.preprocess.enhance_contrast = false;
|
||||
config.performance.parallel = true;
|
||||
config.cache.similarity_threshold = 0.85;
|
||||
config
|
||||
}
|
||||
|
||||
/// Create minimal configuration
|
||||
pub fn minimal() -> Self {
|
||||
let mut config = Self::default();
|
||||
config.cache.enabled = false;
|
||||
config.preprocess.denoise = false;
|
||||
config.preprocess.enhance_contrast = false;
|
||||
config.performance.parallel = false;
|
||||
config
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = Config::default();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.7);
        assert!(config.cache.enabled);
    }

    #[test]
    fn test_high_accuracy_config() {
        let config = Config::high_accuracy();
        assert!(config.validate().is_ok());
        assert_eq!(config.ocr.confidence_threshold, 0.9);
        assert_eq!(config.cache.similarity_threshold, 0.98);
    }

    #[test]
    fn test_high_speed_config() {
        let config = Config::high_speed();
        assert!(config.validate().is_ok());
        assert_eq!(config.model.precision, "fp16");
        assert!(config.model.quantize);
    }

    #[test]
    fn test_minimal_config() {
        let config = Config::minimal();
        assert!(config.validate().is_ok());
        assert!(!config.cache.enabled);
    }

    /// Mutates a default config with `f` and asserts validation rejects it.
    fn assert_invalid(f: impl FnOnce(&mut Config)) {
        let mut config = Config::default();
        f(&mut config);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_confidence_threshold() {
        assert_invalid(|c| c.ocr.confidence_threshold = 1.5);
    }

    #[test]
    fn test_invalid_batch_size() {
        assert_invalid(|c| c.model.batch_size = 0);
    }

    #[test]
    fn test_invalid_precision() {
        assert_invalid(|c| c.model.precision = "invalid".to_string());
    }

    #[test]
    fn test_invalid_output_format() {
        assert_invalid(|c| c.output.formats = vec!["invalid".to_string()]);
    }

    #[test]
    fn test_toml_serialization() {
        // Round-trip through TOML must preserve the configuration.
        let config = Config::default();
        let toml_str = toml::to_string(&config).unwrap();
        let deserialized: Config = toml::from_str(&toml_str).unwrap();
        assert_eq!(
            config.ocr.confidence_threshold,
            deserialized.ocr.confidence_threshold
        );
    }
}
|
||||
228
examples/scipix/src/error.rs
Normal file
228
examples/scipix/src/error.rs
Normal file
@@ -0,0 +1,228 @@
|
||||
//! Error types for Ruvector-Scipix
|
||||
//!
|
||||
//! Comprehensive error handling with context, HTTP status mapping, and retry logic.
|
||||
|
||||
use std::io;
|
||||
use thiserror::Error;
|
||||
|
||||
/// Result type alias for Scipix operations
|
||||
pub type Result<T> = std::result::Result<T, ScipixError>;
|
||||
|
||||
/// Comprehensive error types for all Scipix operations
|
||||
#[derive(Debug, Error)]
|
||||
pub enum ScipixError {
|
||||
/// Image loading or processing error
|
||||
#[error("Image error: {0}")]
|
||||
Image(String),
|
||||
|
||||
/// Machine learning model error
|
||||
#[error("Model error: {0}")]
|
||||
Model(String),
|
||||
|
||||
/// OCR processing error
|
||||
#[error("OCR error: {0}")]
|
||||
Ocr(String),
|
||||
|
||||
/// LaTeX generation or parsing error
|
||||
#[error("LaTeX error: {0}")]
|
||||
LaTeX(String),
|
||||
|
||||
/// Configuration error
|
||||
#[error("Configuration error: {0}")]
|
||||
Config(String),
|
||||
|
||||
/// I/O error
|
||||
#[error("I/O error: {0}")]
|
||||
Io(#[from] io::Error),
|
||||
|
||||
/// Serialization/deserialization error
|
||||
#[error("Serialization error: {0}")]
|
||||
Serialization(String),
|
||||
|
||||
/// Invalid input error
|
||||
#[error("Invalid input: {0}")]
|
||||
InvalidInput(String),
|
||||
|
||||
/// Operation timeout
|
||||
#[error("Timeout: operation took longer than {0}s")]
|
||||
Timeout(u64),
|
||||
|
||||
/// Resource not found
|
||||
#[error("Not found: {0}")]
|
||||
NotFound(String),
|
||||
|
||||
/// Authentication error
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
/// Rate limit exceeded
|
||||
#[error("Rate limit exceeded: {0}")]
|
||||
RateLimit(String),
|
||||
|
||||
/// Internal error
|
||||
#[error("Internal error: {0}")]
|
||||
Internal(String),
|
||||
}
|
||||
|
||||
impl ScipixError {
|
||||
/// Check if the error is retryable
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `true` if the operation should be retried, `false` otherwise
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use ruvector_scipix::ScipixError;
|
||||
///
|
||||
/// let timeout_error = ScipixError::Timeout(30);
|
||||
/// assert!(timeout_error.is_retryable());
|
||||
///
|
||||
/// let config_error = ScipixError::Config("Invalid parameter".to_string());
|
||||
/// assert!(!config_error.is_retryable());
|
||||
/// ```
|
||||
pub fn is_retryable(&self) -> bool {
|
||||
match self {
|
||||
// Retryable errors
|
||||
ScipixError::Timeout(_) => true,
|
||||
ScipixError::RateLimit(_) => true,
|
||||
ScipixError::Io(_) => true,
|
||||
ScipixError::Internal(_) => true,
|
||||
|
||||
// Non-retryable errors
|
||||
ScipixError::Image(_) => false,
|
||||
ScipixError::Model(_) => false,
|
||||
ScipixError::Ocr(_) => false,
|
||||
ScipixError::LaTeX(_) => false,
|
||||
ScipixError::Config(_) => false,
|
||||
ScipixError::Serialization(_) => false,
|
||||
ScipixError::InvalidInput(_) => false,
|
||||
ScipixError::NotFound(_) => false,
|
||||
ScipixError::Auth(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Map error to HTTP status code
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// HTTP status code representing the error type
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust
|
||||
/// use ruvector_scipix::ScipixError;
|
||||
///
|
||||
/// let auth_error = ScipixError::Auth("Invalid token".to_string());
|
||||
/// assert_eq!(auth_error.status_code(), 401);
|
||||
///
|
||||
/// let not_found = ScipixError::NotFound("Model not found".to_string());
|
||||
/// assert_eq!(not_found.status_code(), 404);
|
||||
/// ```
|
||||
pub fn status_code(&self) -> u16 {
|
||||
match self {
|
||||
ScipixError::Auth(_) => 401,
|
||||
ScipixError::NotFound(_) => 404,
|
||||
ScipixError::InvalidInput(_) => 400,
|
||||
ScipixError::RateLimit(_) => 429,
|
||||
ScipixError::Timeout(_) => 408,
|
||||
ScipixError::Config(_) => 400,
|
||||
ScipixError::Internal(_) => 500,
|
||||
_ => 500,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get error category for logging and metrics
|
||||
pub fn category(&self) -> &'static str {
|
||||
match self {
|
||||
ScipixError::Image(_) => "image",
|
||||
ScipixError::Model(_) => "model",
|
||||
ScipixError::Ocr(_) => "ocr",
|
||||
ScipixError::LaTeX(_) => "latex",
|
||||
ScipixError::Config(_) => "config",
|
||||
ScipixError::Io(_) => "io",
|
||||
ScipixError::Serialization(_) => "serialization",
|
||||
ScipixError::InvalidInput(_) => "invalid_input",
|
||||
ScipixError::Timeout(_) => "timeout",
|
||||
ScipixError::NotFound(_) => "not_found",
|
||||
ScipixError::Auth(_) => "auth",
|
||||
ScipixError::RateLimit(_) => "rate_limit",
|
||||
ScipixError::Internal(_) => "internal",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion from serde_json::Error
|
||||
impl From<serde_json::Error> for ScipixError {
|
||||
fn from(err: serde_json::Error) -> Self {
|
||||
ScipixError::Serialization(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion from toml::de::Error
|
||||
impl From<toml::de::Error> for ScipixError {
|
||||
fn from(err: toml::de::Error) -> Self {
|
||||
ScipixError::Config(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Conversion from toml::ser::Error
|
||||
impl From<toml::ser::Error> for ScipixError {
|
||||
fn from(err: toml::ser::Error) -> Self {
|
||||
ScipixError::Serialization(err.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let err = ScipixError::Image("Failed to load".to_string());
        assert_eq!(err.to_string(), "Image error: Failed to load");
    }

    #[test]
    fn test_is_retryable() {
        assert!(ScipixError::Timeout(30).is_retryable());
        assert!(ScipixError::RateLimit("Exceeded".to_string()).is_retryable());
        assert!(!ScipixError::Config("Invalid".to_string()).is_retryable());
        assert!(!ScipixError::Auth("Unauthorized".to_string()).is_retryable());
    }

    #[test]
    fn test_status_codes() {
        // Table-driven: each error variant paired with its expected code.
        let cases = [
            (ScipixError::Auth(String::new()), 401),
            (ScipixError::NotFound(String::new()), 404),
            (ScipixError::InvalidInput(String::new()), 400),
            (ScipixError::RateLimit(String::new()), 429),
            (ScipixError::Timeout(0), 408),
            (ScipixError::Internal(String::new()), 500),
        ];
        for (err, expected) in cases {
            assert_eq!(err.status_code(), expected);
        }
    }

    #[test]
    fn test_category() {
        let cases = [
            (ScipixError::Image(String::new()), "image"),
            (ScipixError::Model(String::new()), "model"),
            (ScipixError::Ocr(String::new()), "ocr"),
            (ScipixError::LaTeX(String::new()), "latex"),
            (ScipixError::Config(String::new()), "config"),
            (ScipixError::Auth(String::new()), "auth"),
        ];
        for (err, expected) in cases {
            assert_eq!(err.category(), expected);
        }
    }

    #[test]
    fn test_from_io_error() {
        let io_err = io::Error::new(io::ErrorKind::NotFound, "File not found");
        let scipix_err: ScipixError = io_err.into();
        assert!(matches!(scipix_err, ScipixError::Io(_)));
    }

    #[test]
    fn test_from_json_error() {
        let json_err = serde_json::from_str::<serde_json::Value>("invalid json").unwrap_err();
        let scipix_err: ScipixError = json_err.into();
        assert!(matches!(scipix_err, ScipixError::Serialization(_)));
    }
}
|
||||
129
examples/scipix/src/lib.rs
Normal file
129
examples/scipix/src/lib.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
//! # Ruvector-Scipix
|
||||
//!
|
||||
//! A high-performance Rust implementation of Scipix OCR for mathematical expressions and equations.
|
||||
//! Built on top of ruvector-core for efficient vector-based caching and similarity search.
|
||||
//!
|
||||
//! ## Features
|
||||
//!
|
||||
//! - **Mathematical OCR**: Extract LaTeX from images of equations
|
||||
//! - **Vector Caching**: Intelligent caching using image embeddings
|
||||
//! - **Multiple Formats**: Support for LaTeX, MathML, AsciiMath
|
||||
//! - **High Performance**: Parallel processing and efficient caching
|
||||
//! - **Configurable**: Extensive configuration options via TOML or API
|
||||
//!
|
||||
//! ## Quick Start
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use ruvector_scipix::{Config, OcrEngine, Result};
|
||||
//!
|
||||
//! #[tokio::main]
|
||||
//! async fn main() -> Result<()> {
|
||||
//! // Load configuration
|
||||
//! let config = Config::from_file("scipix.toml")?;
|
||||
//!
|
||||
//! // Create OCR engine
|
||||
//! let engine = OcrEngine::new(config).await?;
|
||||
//!
|
||||
//! // Process image
|
||||
//! let result = engine.process_image("equation.png").await?;
|
||||
//! println!("LaTeX: {}", result.latex);
|
||||
//!
|
||||
//! Ok(())
|
||||
//! }
|
||||
//! ```
|
||||
//!
|
||||
//! ## Architecture
|
||||
//!
|
||||
//! - **config**: Configuration management with TOML support
|
||||
//! - **error**: Comprehensive error types with context
|
||||
//! - **math**: LaTeX and mathematical format handling
|
||||
//! - **ocr**: Core OCR processing engine
|
||||
//! - **output**: Output formatting and serialization
|
||||
//! - **preprocess**: Image preprocessing pipeline
|
||||
//! - **cache**: Vector-based intelligent caching
|
||||
|
||||
// Module declarations
|
||||
pub mod api;
|
||||
pub mod cli;
|
||||
pub mod config;
|
||||
pub mod error;
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
pub mod cache;
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
pub mod ocr;
|
||||
|
||||
#[cfg(feature = "math")]
|
||||
pub mod math;
|
||||
|
||||
#[cfg(feature = "preprocess")]
|
||||
pub mod preprocess;
|
||||
|
||||
// Output module is always available
|
||||
pub mod output;
|
||||
|
||||
// Performance optimizations
|
||||
#[cfg(feature = "optimize")]
|
||||
pub mod optimize;
|
||||
|
||||
// WebAssembly bindings
|
||||
#[cfg(all(target_arch = "wasm32", feature = "wasm"))]
|
||||
pub mod wasm;
|
||||
|
||||
// Public re-exports
|
||||
pub use api::{state::AppState, ApiServer};
|
||||
pub use cli::{Cli, Commands};
|
||||
pub use config::{
|
||||
CacheConfig, Config, ModelConfig, OcrConfig, OutputConfig, PerformanceConfig, PreprocessConfig,
|
||||
};
|
||||
pub use error::{Result, ScipixError};
|
||||
|
||||
#[cfg(feature = "cache")]
|
||||
pub use cache::CacheManager;
|
||||
|
||||
/// Library version
|
||||
pub const VERSION: &str = env!("CARGO_PKG_VERSION");
|
||||
|
||||
/// Default configuration preset
|
||||
pub fn default_config() -> Config {
|
||||
Config::default()
|
||||
}
|
||||
|
||||
/// High-accuracy configuration preset
|
||||
pub fn high_accuracy_config() -> Config {
|
||||
Config::high_accuracy()
|
||||
}
|
||||
|
||||
/// High-speed configuration preset
|
||||
pub fn high_speed_config() -> Config {
|
||||
Config::high_speed()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_version() {
        assert!(!VERSION.is_empty());
    }

    #[test]
    fn test_default_config() {
        assert!(default_config().validate().is_ok());
    }

    #[test]
    fn test_high_accuracy_config() {
        assert!(high_accuracy_config().validate().is_ok());
    }

    #[test]
    fn test_high_speed_config() {
        assert!(high_speed_config().validate().is_ok());
    }
}
|
||||
465
examples/scipix/src/math/asciimath.rs
Normal file
465
examples/scipix/src/math/asciimath.rs
Normal file
@@ -0,0 +1,465 @@
|
||||
//! AsciiMath generation from mathematical AST
|
||||
//!
|
||||
//! This module converts mathematical AST nodes to AsciiMath notation,
|
||||
//! a simplified plain-text format for mathematical expressions.
|
||||
|
||||
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
|
||||
|
||||
/// Converts a mathematical AST into AsciiMath notation.
pub struct AsciiMathGenerator {
    /// Emit Unicode symbols when true; plain-ASCII fallbacks otherwise.
    unicode: bool,
}
|
||||
|
||||
impl AsciiMathGenerator {
|
||||
/// Create a new AsciiMath generator with Unicode support
|
||||
pub fn new() -> Self {
|
||||
Self { unicode: true }
|
||||
}
|
||||
|
||||
/// Create an ASCII-only generator
|
||||
pub fn ascii_only() -> Self {
|
||||
Self { unicode: false }
|
||||
}
|
||||
|
||||
/// Generate AsciiMath string from a mathematical expression
|
||||
pub fn generate(&self, expr: &MathExpr) -> String {
|
||||
self.generate_node(&expr.root, None)
|
||||
}
|
||||
|
||||
/// Generate AsciiMath for a single node
|
||||
fn generate_node(&self, node: &MathNode, parent_precedence: Option<u8>) -> String {
|
||||
match node {
|
||||
MathNode::Symbol { value, .. } => value.clone(),
|
||||
|
||||
MathNode::Number { value, .. } => value.clone(),
|
||||
|
||||
MathNode::Binary { op, left, right } => {
|
||||
let precedence = op.precedence();
|
||||
let needs_parens = parent_precedence.map_or(false, |p| precedence < p);
|
||||
|
||||
let left_str = self.generate_node(left, Some(precedence));
|
||||
let right_str = self.generate_node(
|
||||
right,
|
||||
Some(if op.is_left_associative() {
|
||||
precedence
|
||||
} else {
|
||||
precedence + 1
|
||||
}),
|
||||
);
|
||||
|
||||
let op_str = self.binary_op_to_asciimath(op);
|
||||
|
||||
let result = format!("{} {} {}", left_str, op_str, right_str);
|
||||
|
||||
if needs_parens {
|
||||
format!("({})", result)
|
||||
} else {
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
MathNode::Unary { op, operand } => {
|
||||
let op_str = self.unary_op_to_asciimath(op);
|
||||
let operand_str = self.generate_node(operand, Some(70));
|
||||
format!("{}{}", op_str, operand_str)
|
||||
}
|
||||
|
||||
MathNode::Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} => {
|
||||
let num_str = self.generate_node(numerator, None);
|
||||
let den_str = self.generate_node(denominator, None);
|
||||
format!("({})/({})", num_str, den_str)
|
||||
}
|
||||
|
||||
MathNode::Radical { index, radicand } => {
|
||||
let rad_str = self.generate_node(radicand, None);
|
||||
if let Some(idx) = index {
|
||||
let idx_str = self.generate_node(idx, None);
|
||||
format!("root({})({} )", idx_str, rad_str)
|
||||
} else {
|
||||
format!("sqrt({})", rad_str)
|
||||
}
|
||||
}
|
||||
|
||||
MathNode::Script {
|
||||
base,
|
||||
subscript,
|
||||
superscript,
|
||||
} => {
|
||||
let base_str = self.generate_node(base, Some(65));
|
||||
let mut result = base_str;
|
||||
|
||||
if let Some(sub) = subscript {
|
||||
let sub_str = self.generate_node(sub, None);
|
||||
result.push_str(&format!("_{{{}}}", sub_str));
|
||||
}
|
||||
|
||||
if let Some(sup) = superscript {
|
||||
let sup_str = self.generate_node(sup, None);
|
||||
result.push_str(&format!("^{{{}}}", sup_str));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
MathNode::Function { name, argument } => {
|
||||
let arg_str = self.generate_node(argument, None);
|
||||
format!("{}({})", name, arg_str)
|
||||
}
|
||||
|
||||
MathNode::Matrix { rows, .. } => {
|
||||
let mut content = String::new();
|
||||
content.push('[');
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
if i > 0 {
|
||||
content.push_str("; ");
|
||||
}
|
||||
for (j, elem) in row.iter().enumerate() {
|
||||
if j > 0 {
|
||||
content.push_str(", ");
|
||||
}
|
||||
content.push_str(&self.generate_node(elem, None));
|
||||
}
|
||||
}
|
||||
|
||||
content.push(']');
|
||||
content
|
||||
}
|
||||
|
||||
MathNode::Group {
|
||||
content,
|
||||
bracket_type,
|
||||
} => {
|
||||
let content_str = self.generate_node(content, None);
|
||||
let (open, close) = match bracket_type {
|
||||
BracketType::Parentheses => ("(", ")"),
|
||||
BracketType::Brackets => ("[", "]"),
|
||||
BracketType::Braces => ("{", "}"),
|
||||
BracketType::AngleBrackets => {
|
||||
if self.unicode {
|
||||
("⟨", "⟩")
|
||||
} else {
|
||||
("<", ">")
|
||||
}
|
||||
}
|
||||
BracketType::Vertical => ("|", "|"),
|
||||
BracketType::DoubleVertical => {
|
||||
if self.unicode {
|
||||
("‖", "‖")
|
||||
} else {
|
||||
("||", "||")
|
||||
}
|
||||
}
|
||||
BracketType::Floor => {
|
||||
if self.unicode {
|
||||
("⌊", "⌋")
|
||||
} else {
|
||||
("|_", "_|")
|
||||
}
|
||||
}
|
||||
BracketType::Ceiling => {
|
||||
if self.unicode {
|
||||
("⌈", "⌉")
|
||||
} else {
|
||||
("|^", "^|")
|
||||
}
|
||||
}
|
||||
BracketType::None => ("", ""),
|
||||
};
|
||||
|
||||
format!("{}{}{}", open, content_str, close)
|
||||
}
|
||||
|
||||
MathNode::LargeOp {
|
||||
op_type,
|
||||
lower,
|
||||
upper,
|
||||
content,
|
||||
} => {
|
||||
let op_str = self.large_op_to_asciimath(op_type);
|
||||
let content_str = self.generate_node(content, None);
|
||||
|
||||
let mut result = op_str.to_string();
|
||||
|
||||
if let Some(low) = lower {
|
||||
let low_str = self.generate_node(low, None);
|
||||
result.push_str(&format!("_{{{}}}", low_str));
|
||||
}
|
||||
|
||||
if let Some(up) = upper {
|
||||
let up_str = self.generate_node(up, None);
|
||||
result.push_str(&format!("^{{{}}}", up_str));
|
||||
}
|
||||
|
||||
format!("{} {}", result, content_str)
|
||||
}
|
||||
|
||||
MathNode::Sequence { elements } => elements
|
||||
.iter()
|
||||
.map(|e| self.generate_node(e, None))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
|
||||
MathNode::Text { content } => {
|
||||
format!("\"{}\"", content)
|
||||
}
|
||||
|
||||
MathNode::Empty => String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert binary operator to AsciiMath
|
||||
fn binary_op_to_asciimath<'a>(&self, op: &'a BinaryOp) -> &'a str {
|
||||
if self.unicode {
|
||||
match op {
|
||||
BinaryOp::Add => "+",
|
||||
BinaryOp::Subtract => "-",
|
||||
BinaryOp::Multiply => "×",
|
||||
BinaryOp::Divide => "÷",
|
||||
BinaryOp::Power => "^",
|
||||
BinaryOp::Equal => "=",
|
||||
BinaryOp::NotEqual => "≠",
|
||||
BinaryOp::Less => "<",
|
||||
BinaryOp::Greater => ">",
|
||||
BinaryOp::LessEqual => "≤",
|
||||
BinaryOp::GreaterEqual => "≥",
|
||||
BinaryOp::ApproxEqual => "≈",
|
||||
BinaryOp::Equivalent => "≡",
|
||||
BinaryOp::Similar => "∼",
|
||||
BinaryOp::Congruent => "≅",
|
||||
BinaryOp::Proportional => "∝",
|
||||
BinaryOp::Custom(s) => s,
|
||||
}
|
||||
} else {
|
||||
match op {
|
||||
BinaryOp::Add => "+",
|
||||
BinaryOp::Subtract => "-",
|
||||
BinaryOp::Multiply => "*",
|
||||
BinaryOp::Divide => "/",
|
||||
BinaryOp::Power => "^",
|
||||
BinaryOp::Equal => "=",
|
||||
BinaryOp::NotEqual => "!=",
|
||||
BinaryOp::Less => "<",
|
||||
BinaryOp::Greater => ">",
|
||||
BinaryOp::LessEqual => "<=",
|
||||
BinaryOp::GreaterEqual => ">=",
|
||||
BinaryOp::ApproxEqual => "~~",
|
||||
BinaryOp::Equivalent => "-=",
|
||||
BinaryOp::Similar => "~",
|
||||
BinaryOp::Congruent => "~=",
|
||||
BinaryOp::Proportional => "prop",
|
||||
BinaryOp::Custom(s) => s.as_str(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert unary operator to AsciiMath
|
||||
fn unary_op_to_asciimath<'a>(&self, op: &'a UnaryOp) -> &'a str {
|
||||
match op {
|
||||
UnaryOp::Plus => "+",
|
||||
UnaryOp::Minus => "-",
|
||||
UnaryOp::Not => {
|
||||
if self.unicode {
|
||||
"¬"
|
||||
} else {
|
||||
"not "
|
||||
}
|
||||
}
|
||||
UnaryOp::Custom(s) => s.as_str(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert large operator to AsciiMath
|
||||
fn large_op_to_asciimath(&self, op: &LargeOpType) -> &str {
|
||||
if self.unicode {
|
||||
match op {
|
||||
LargeOpType::Sum => "∑",
|
||||
LargeOpType::Product => "∏",
|
||||
LargeOpType::Integral => "∫",
|
||||
LargeOpType::DoubleIntegral => "∬",
|
||||
LargeOpType::TripleIntegral => "∭",
|
||||
LargeOpType::ContourIntegral => "∮",
|
||||
LargeOpType::Union => "⋃",
|
||||
LargeOpType::Intersection => "⋂",
|
||||
LargeOpType::Coproduct => "∐",
|
||||
LargeOpType::DirectSum => "⊕",
|
||||
LargeOpType::Custom(_) => "sum",
|
||||
}
|
||||
} else {
|
||||
match op {
|
||||
LargeOpType::Sum => "sum",
|
||||
LargeOpType::Product => "prod",
|
||||
LargeOpType::Integral => "int",
|
||||
LargeOpType::DoubleIntegral => "iint",
|
||||
LargeOpType::TripleIntegral => "iiint",
|
||||
LargeOpType::ContourIntegral => "oint",
|
||||
LargeOpType::Union => "cup",
|
||||
LargeOpType::Intersection => "cap",
|
||||
LargeOpType::Coproduct => "coprod",
|
||||
LargeOpType::DirectSum => "oplus",
|
||||
LargeOpType::Custom(_) => "sum",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AsciiMathGenerator {
    /// Delegates to [`AsciiMathGenerator::new`]. Per the tests below,
    /// `new()` produces Unicode output (e.g. `×`), while
    /// `ascii_only()` is the ASCII alternative.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a non-decimal numeric literal node.
    fn int(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    /// Wrap a root node in a `MathExpr` with full confidence.
    fn expr(root: MathNode) -> MathExpr {
        MathExpr::new(root, 1.0)
    }

    #[test]
    fn test_simple_number() {
        assert_eq!(AsciiMathGenerator::new().generate(&expr(int("42"))), "42");
    }

    #[test]
    fn test_addition() {
        let e = expr(MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(int("1")),
            right: Box::new(int("2")),
        });
        assert_eq!(AsciiMathGenerator::new().generate(&e), "1 + 2");
    }

    #[test]
    fn test_fraction() {
        let e = expr(MathNode::Fraction {
            numerator: Box::new(int("1")),
            denominator: Box::new(int("2")),
        });
        // Both sides are parenthesized unconditionally.
        assert_eq!(AsciiMathGenerator::new().generate(&e), "(1)/(2)");
    }

    #[test]
    fn test_sqrt() {
        let e = expr(MathNode::Radical {
            index: None,
            radicand: Box::new(int("2")),
        });
        assert_eq!(AsciiMathGenerator::new().generate(&e), "sqrt(2)");
    }

    #[test]
    fn test_superscript() {
        let e = expr(MathNode::Script {
            base: Box::new(MathNode::Symbol {
                value: "x".to_string(),
                unicode: Some('x'),
            }),
            subscript: None,
            superscript: Some(Box::new(int("2"))),
        });
        assert_eq!(AsciiMathGenerator::new().generate(&e), "x^{2}");
    }

    #[test]
    fn test_unicode_vs_ascii() {
        let e = expr(MathNode::Binary {
            op: BinaryOp::Multiply,
            left: Box::new(int("2")),
            right: Box::new(int("3")),
        });
        // Same tree, two output modes.
        assert_eq!(AsciiMathGenerator::new().generate(&e), "2 × 3");
        assert_eq!(AsciiMathGenerator::ascii_only().generate(&e), "2 * 3");
    }

    #[test]
    fn test_matrix() {
        let e = expr(MathNode::Matrix {
            rows: vec![vec![int("1"), int("2")], vec![int("3"), int("4")]],
            bracket_type: BracketType::Brackets,
        });
        // Rows are ';'-separated, elements ','-separated.
        assert_eq!(AsciiMathGenerator::new().generate(&e), "[1, 2; 3, 4]");
    }
}
|
||||
437
examples/scipix/src/math/ast.rs
Normal file
437
examples/scipix/src/math/ast.rs
Normal file
@@ -0,0 +1,437 @@
|
||||
//! Abstract Syntax Tree definitions for mathematical expressions
|
||||
//!
|
||||
//! This module defines the complete AST structure for representing mathematical
|
||||
//! expressions including symbols, operators, fractions, matrices, and more.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt;
|
||||
|
||||
/// A complete mathematical expression with confidence score.
///
/// This is the top-level value produced by recognition: an AST plus
/// how confident the recognizer was in it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MathExpr {
    /// Root node of the expression tree
    pub root: MathNode,
    /// Confidence score (0.0 to 1.0) from OCR recognition
    pub confidence: f32,
}
|
||||
|
||||
impl MathExpr {
|
||||
/// Create a new mathematical expression
|
||||
pub fn new(root: MathNode, confidence: f32) -> Self {
|
||||
Self { root, confidence }
|
||||
}
|
||||
|
||||
/// Accept a visitor for tree traversal
|
||||
pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
|
||||
self.root.accept(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
/// Main AST node representing any mathematical construct
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum MathNode {
    /// A mathematical symbol (variable, Greek letter, operator)
    Symbol {
        /// Symbol text as recognized
        value: String,
        /// Unicode code point for the symbol, when known; used by the
        /// LaTeX generator for macro lookup (`unicode_to_latex`)
        unicode: Option<char>,
    },

    /// A numeric value
    Number {
        /// Digit string (kept as text, not parsed)
        value: String,
        /// Whether this is part of a decimal number
        is_decimal: bool,
    },

    /// Binary operation (a op b)
    Binary {
        op: BinaryOp,
        left: Box<MathNode>,
        right: Box<MathNode>,
    },

    /// Unary operation (op a)
    Unary { op: UnaryOp, operand: Box<MathNode> },

    /// Fraction (numerator / denominator)
    Fraction {
        numerator: Box<MathNode>,
        denominator: Box<MathNode>,
    },

    /// Radical (√, ∛, etc.)
    Radical {
        /// Index of the radical (2 for square root, 3 for cube root, etc.);
        /// `None` means a plain square root
        index: Option<Box<MathNode>>,
        radicand: Box<MathNode>,
    },

    /// Subscript or superscript
    Script {
        base: Box<MathNode>,
        subscript: Option<Box<MathNode>>,
        superscript: Option<Box<MathNode>>,
    },

    /// Function application (sin, cos, log, etc.)
    Function {
        name: String,
        argument: Box<MathNode>,
    },

    /// Matrix or vector
    Matrix {
        /// Row-major element grid
        rows: Vec<Vec<MathNode>>,
        bracket_type: BracketType,
    },

    /// Grouped expression with delimiters
    Group {
        content: Box<MathNode>,
        bracket_type: BracketType,
    },

    /// Large operators (∑, ∫, ∏, etc.)
    LargeOp {
        op_type: LargeOpType,
        /// Lower limit (subscript), if any
        lower: Option<Box<MathNode>>,
        /// Upper limit (superscript), if any
        upper: Option<Box<MathNode>>,
        content: Box<MathNode>,
    },

    /// Sequence of expressions (e.g., function arguments)
    Sequence { elements: Vec<MathNode> },

    /// Text annotation in math mode
    Text { content: String },

    /// Empty/placeholder node
    Empty,
}
|
||||
|
||||
impl MathNode {
|
||||
/// Accept a visitor for tree traversal
|
||||
pub fn accept<V: MathVisitor>(&self, visitor: &mut V) {
|
||||
visitor.visit(self);
|
||||
match self {
|
||||
MathNode::Binary { left, right, .. } => {
|
||||
left.accept(visitor);
|
||||
right.accept(visitor);
|
||||
}
|
||||
MathNode::Unary { operand, .. } => {
|
||||
operand.accept(visitor);
|
||||
}
|
||||
MathNode::Fraction {
|
||||
numerator,
|
||||
denominator,
|
||||
} => {
|
||||
numerator.accept(visitor);
|
||||
denominator.accept(visitor);
|
||||
}
|
||||
MathNode::Radical { index, radicand } => {
|
||||
if let Some(idx) = index {
|
||||
idx.accept(visitor);
|
||||
}
|
||||
radicand.accept(visitor);
|
||||
}
|
||||
MathNode::Script {
|
||||
base,
|
||||
subscript,
|
||||
superscript,
|
||||
} => {
|
||||
base.accept(visitor);
|
||||
if let Some(sub) = subscript {
|
||||
sub.accept(visitor);
|
||||
}
|
||||
if let Some(sup) = superscript {
|
||||
sup.accept(visitor);
|
||||
}
|
||||
}
|
||||
MathNode::Function { argument, .. } => {
|
||||
argument.accept(visitor);
|
||||
}
|
||||
MathNode::Matrix { rows, .. } => {
|
||||
for row in rows {
|
||||
for elem in row {
|
||||
elem.accept(visitor);
|
||||
}
|
||||
}
|
||||
}
|
||||
MathNode::Group { content, .. } => {
|
||||
content.accept(visitor);
|
||||
}
|
||||
MathNode::LargeOp {
|
||||
lower,
|
||||
upper,
|
||||
content,
|
||||
..
|
||||
} => {
|
||||
if let Some(l) = lower {
|
||||
l.accept(visitor);
|
||||
}
|
||||
if let Some(u) = upper {
|
||||
u.accept(visitor);
|
||||
}
|
||||
content.accept(visitor);
|
||||
}
|
||||
MathNode::Sequence { elements } => {
|
||||
for elem in elements {
|
||||
elem.accept(visitor);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Binary operators
///
/// Covers arithmetic, exponentiation, and relational/equivalence
/// operators, plus an escape hatch for anything else.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOp {
    Add,
    Subtract,
    Multiply,
    Divide,
    Power,
    Equal,
    NotEqual,
    Less,
    Greater,
    LessEqual,
    GreaterEqual,
    ApproxEqual,
    Equivalent,
    Similar,
    Congruent,
    Proportional,
    /// Custom operator with LaTeX representation
    Custom(String),
}
|
||||
|
||||
impl BinaryOp {
|
||||
/// Get precedence level (higher = binds tighter)
|
||||
pub fn precedence(&self) -> u8 {
|
||||
match self {
|
||||
BinaryOp::Power => 60,
|
||||
BinaryOp::Multiply | BinaryOp::Divide => 50,
|
||||
BinaryOp::Add | BinaryOp::Subtract => 40,
|
||||
BinaryOp::Equal
|
||||
| BinaryOp::NotEqual
|
||||
| BinaryOp::Less
|
||||
| BinaryOp::Greater
|
||||
| BinaryOp::LessEqual
|
||||
| BinaryOp::GreaterEqual
|
||||
| BinaryOp::ApproxEqual
|
||||
| BinaryOp::Equivalent
|
||||
| BinaryOp::Similar
|
||||
| BinaryOp::Congruent
|
||||
| BinaryOp::Proportional => 30,
|
||||
BinaryOp::Custom(_) => 35,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if operator is left-associative
|
||||
pub fn is_left_associative(&self) -> bool {
|
||||
!matches!(self, BinaryOp::Power)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for BinaryOp {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
BinaryOp::Add => write!(f, "+"),
|
||||
BinaryOp::Subtract => write!(f, "-"),
|
||||
BinaryOp::Multiply => write!(f, "×"),
|
||||
BinaryOp::Divide => write!(f, "÷"),
|
||||
BinaryOp::Power => write!(f, "^"),
|
||||
BinaryOp::Equal => write!(f, "="),
|
||||
BinaryOp::NotEqual => write!(f, "≠"),
|
||||
BinaryOp::Less => write!(f, "<"),
|
||||
BinaryOp::Greater => write!(f, ">"),
|
||||
BinaryOp::LessEqual => write!(f, "≤"),
|
||||
BinaryOp::GreaterEqual => write!(f, "≥"),
|
||||
BinaryOp::ApproxEqual => write!(f, "≈"),
|
||||
BinaryOp::Equivalent => write!(f, "≡"),
|
||||
BinaryOp::Similar => write!(f, "∼"),
|
||||
BinaryOp::Congruent => write!(f, "≅"),
|
||||
BinaryOp::Proportional => write!(f, "∝"),
|
||||
BinaryOp::Custom(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unary operators
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOp {
    /// Unary plus (+a)
    Plus,
    /// Negation (-a)
    Minus,
    /// Logical not (¬a)
    Not,
    /// Custom unary operator
    Custom(String),
}
|
||||
|
||||
impl fmt::Display for UnaryOp {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
UnaryOp::Plus => write!(f, "+"),
|
||||
UnaryOp::Minus => write!(f, "-"),
|
||||
UnaryOp::Not => write!(f, "¬"),
|
||||
UnaryOp::Custom(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Large operator types (∑, ∫, etc.)
///
/// "Large" operators are the n-ary/limit-carrying operators; in the
/// AST they appear inside [`MathNode::LargeOp`] together with
/// optional lower/upper limits.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum LargeOpType {
    Sum,             // ∑
    Product,         // ∏
    Integral,        // ∫
    DoubleIntegral,  // ∬
    TripleIntegral,  // ∭
    ContourIntegral, // ∮
    Union,           // ⋃
    Intersection,    // ⋂
    Coproduct,       // ∐
    DirectSum,       // ⊕
    /// Custom large operator carrying its own text
    Custom(String),
}
|
||||
|
||||
impl fmt::Display for LargeOpType {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
LargeOpType::Sum => write!(f, "∑"),
|
||||
LargeOpType::Product => write!(f, "∏"),
|
||||
LargeOpType::Integral => write!(f, "∫"),
|
||||
LargeOpType::DoubleIntegral => write!(f, "∬"),
|
||||
LargeOpType::TripleIntegral => write!(f, "∭"),
|
||||
LargeOpType::ContourIntegral => write!(f, "∮"),
|
||||
LargeOpType::Union => write!(f, "⋃"),
|
||||
LargeOpType::Intersection => write!(f, "⋂"),
|
||||
LargeOpType::Coproduct => write!(f, "∐"),
|
||||
LargeOpType::DirectSum => write!(f, "⊕"),
|
||||
LargeOpType::Custom(s) => write!(f, "{}", s),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Bracket types for grouping and matrices
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BracketType {
    Parentheses,    // ( )
    Brackets,       // [ ]
    Braces,         // { }
    AngleBrackets,  // ⟨ ⟩
    Vertical,       // | |
    DoubleVertical, // ‖ ‖
    Floor,          // ⌊ ⌋
    Ceiling,        // ⌈ ⌉
    None,           // No brackets
}
|
||||
|
||||
impl BracketType {
|
||||
/// Get opening delimiter
|
||||
pub fn opening(&self) -> &str {
|
||||
match self {
|
||||
BracketType::Parentheses => "(",
|
||||
BracketType::Brackets => "[",
|
||||
BracketType::Braces => "{",
|
||||
BracketType::AngleBrackets => "⟨",
|
||||
BracketType::Vertical => "|",
|
||||
BracketType::DoubleVertical => "‖",
|
||||
BracketType::Floor => "⌊",
|
||||
BracketType::Ceiling => "⌈",
|
||||
BracketType::None => "",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get closing delimiter
|
||||
pub fn closing(&self) -> &str {
|
||||
match self {
|
||||
BracketType::Parentheses => ")",
|
||||
BracketType::Brackets => "]",
|
||||
BracketType::Braces => "}",
|
||||
BracketType::AngleBrackets => "⟩",
|
||||
BracketType::Vertical => "|",
|
||||
BracketType::DoubleVertical => "‖",
|
||||
BracketType::Floor => "⌋",
|
||||
BracketType::Ceiling => "⌉",
|
||||
BracketType::None => "",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Visitor pattern for traversing the AST
pub trait MathVisitor {
    /// Called once for every node during [`MathNode::accept`]'s
    /// pre-order walk (parents before children).
    fn visit(&mut self, node: &MathNode);
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a non-decimal numeric literal node.
    fn int(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    #[test]
    fn test_binary_op_precedence() {
        // power > multiplicative > additive > relational
        assert!(BinaryOp::Power.precedence() > BinaryOp::Multiply.precedence());
        assert!(BinaryOp::Multiply.precedence() > BinaryOp::Add.precedence());
        assert!(BinaryOp::Add.precedence() > BinaryOp::Equal.precedence());
    }

    #[test]
    fn test_binary_op_associativity() {
        assert!(BinaryOp::Add.is_left_associative());
        assert!(BinaryOp::Multiply.is_left_associative());
        // Exponentiation is the lone right-associative operator.
        assert!(!BinaryOp::Power.is_left_associative());
    }

    #[test]
    fn test_bracket_delimiters() {
        assert_eq!(BracketType::Parentheses.opening(), "(");
        assert_eq!(BracketType::Parentheses.closing(), ")");
        assert_eq!(BracketType::Brackets.opening(), "[");
        assert_eq!(BracketType::Braces.closing(), "}");
    }

    #[test]
    fn test_math_expr_creation() {
        let expr = MathExpr::new(int("42"), 0.95);
        assert_eq!(expr.confidence, 0.95);
    }

    #[test]
    fn test_visitor_pattern() {
        struct CountVisitor {
            count: usize,
        }

        impl MathVisitor for CountVisitor {
            fn visit(&mut self, _node: &MathNode) {
                self.count += 1;
            }
        }

        let expr = MathExpr::new(
            MathNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(int("1")),
                right: Box::new(int("2")),
            },
            1.0,
        );

        let mut visitor = CountVisitor { count: 0 };
        expr.accept(&mut visitor);
        assert_eq!(visitor.count, 3); // Binary + 2 numbers
    }
}
|
||||
608
examples/scipix/src/math/latex.rs
Normal file
608
examples/scipix/src/math/latex.rs
Normal file
@@ -0,0 +1,608 @@
|
||||
//! LaTeX generation from mathematical AST
|
||||
//!
|
||||
//! This module converts mathematical AST nodes to LaTeX strings with proper
|
||||
//! formatting, precedence handling, and delimiter placement.
|
||||
|
||||
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
|
||||
use crate::math::symbols::unicode_to_latex;
|
||||
|
||||
/// Configuration for LaTeX generation
#[derive(Debug, Clone)]
pub struct LaTeXConfig {
    /// Use display style (true) or inline style (false)
    ///
    /// NOTE(review): not read anywhere in `LaTeXGenerator` in this
    /// file — confirm whether it is consumed elsewhere or is dead.
    pub display_style: bool,
    /// Use \left and \right for delimiters
    pub auto_size_delimiters: bool,
    /// Insert spaces around operators
    pub spacing: bool,
}
|
||||
|
||||
impl Default for LaTeXConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
display_style: false,
|
||||
auto_size_delimiters: true,
|
||||
spacing: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LaTeX generator for mathematical expressions
///
/// Converts a [`MathExpr`] AST into a LaTeX string; behavior is tuned
/// via [`LaTeXConfig`].
pub struct LaTeXGenerator {
    // Generation options (delimiter sizing, operator spacing, …).
    config: LaTeXConfig,
}
|
||||
|
||||
impl LaTeXGenerator {
    /// Create a new LaTeX generator with default configuration
    pub fn new() -> Self {
        Self {
            config: LaTeXConfig::default(),
        }
    }

    /// Create a new LaTeX generator with custom configuration
    pub fn with_config(config: LaTeXConfig) -> Self {
        Self { config }
    }

    /// Generate LaTeX string from a mathematical expression
    pub fn generate(&self, expr: &MathExpr) -> String {
        self.generate_node(&expr.root, None)
    }

    /// Generate LaTeX for a single node
    ///
    /// `parent_precedence` is the binding strength of the enclosing
    /// operator, if any; a `Binary` child wraps itself in parentheses
    /// when its own precedence is strictly lower.
    fn generate_node(&self, node: &MathNode, parent_precedence: Option<u8>) -> String {
        match node {
            // Prefer a LaTeX macro for known Unicode symbols; fall back
            // to the raw symbol text.
            MathNode::Symbol { value, unicode } => {
                if let Some(c) = unicode {
                    if let Some(latex) = unicode_to_latex(*c) {
                        return format!("\\{}", latex);
                    }
                }
                value.clone()
            }

            MathNode::Number { value, .. } => value.clone(),

            MathNode::Binary { op, left, right } => {
                let precedence = op.precedence();
                let needs_parens = parent_precedence.map_or(false, |p| precedence < p);

                // NOTE(review): the right child of a RIGHT-associative op
                // (power) receives `precedence + 1`, which forces parens on
                // an equal-precedence right child — this looks inverted
                // relative to conventional associativity handling (usually
                // the LEFT-associative case bumps the right child). Confirm
                // intent before changing; the AsciiMath generator above
                // mirrors the same logic.
                let left_str = self.generate_node(left, Some(precedence));
                let right_str = self.generate_node(
                    right,
                    Some(if op.is_left_associative() {
                        precedence
                    } else {
                        precedence + 1
                    }),
                );

                let op_str = self.binary_op_to_latex(op);
                let space = if self.config.spacing { " " } else { "" };

                let result = format!("{}{}{}{}{}", left_str, space, op_str, space, right_str);

                if needs_parens {
                    self.wrap_parens(&result)
                } else {
                    result
                }
            }

            MathNode::Unary { op, operand } => {
                let op_str = self.unary_op_to_latex(op);
                let operand_str = self.generate_node(operand, Some(70)); // High precedence
                format!("{}{}", op_str, operand_str)
            }

            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                let num_str = self.generate_node(numerator, None);
                let den_str = self.generate_node(denominator, None);
                format!("\\frac{{{}}}{{{}}}", num_str, den_str)
            }

            MathNode::Radical { index, radicand } => {
                let rad_str = self.generate_node(radicand, None);
                if let Some(idx) = index {
                    // n-th root: \sqrt[n]{x}
                    let idx_str = self.generate_node(idx, None);
                    format!("\\sqrt[{}]{{{}}}", idx_str, rad_str)
                } else {
                    format!("\\sqrt{{{}}}", rad_str)
                }
            }

            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                let base_str = self.generate_node(base, Some(65));
                let mut result = base_str;

                // Scripts are always brace-wrapped: x_{i}, x^{2}.
                if let Some(sub) = subscript {
                    let sub_str = self.generate_node(sub, None);
                    result.push_str(&format!("_{{{}}}", sub_str));
                }

                if let Some(sup) = superscript {
                    let sup_str = self.generate_node(sup, None);
                    result.push_str(&format!("^{{{}}}", sup_str));
                }

                result
            }

            MathNode::Function { name, argument } => {
                let arg_str = self.generate_node(argument, None);
                // Check if it's a standard function
                if is_standard_function(name) {
                    // Known names get their upright macro: \sin x
                    format!("\\{} {}", name, arg_str)
                } else {
                    // Unknown names are typeset as text: \text{f}(x)
                    format!("\\text{{{}}}({})", name, arg_str)
                }
            }

            MathNode::Matrix { rows, bracket_type } => {
                // Map the bracket style onto the matching amsmath
                // environment; Angle/Floor/Ceiling/None fall back to
                // the bare `matrix` environment.
                let env = match bracket_type {
                    BracketType::Parentheses => "pmatrix",
                    BracketType::Brackets => "bmatrix",
                    BracketType::Braces => "Bmatrix",
                    BracketType::Vertical => "vmatrix",
                    BracketType::DoubleVertical => "Vmatrix",
                    _ => "matrix",
                };

                // Rows separated by \\, columns by &.
                let mut content = String::new();
                for (i, row) in rows.iter().enumerate() {
                    if i > 0 {
                        content.push_str(" \\\\ ");
                    }
                    for (j, elem) in row.iter().enumerate() {
                        if j > 0 {
                            content.push_str(" & ");
                        }
                        content.push_str(&self.generate_node(elem, None));
                    }
                }

                format!("\\begin{{{}}} {} \\end{{{}}}", env, content, env)
            }

            MathNode::Group {
                content,
                bracket_type,
            } => {
                let content_str = self.generate_node(content, None);
                self.wrap_with_brackets(&content_str, *bracket_type)
            }

            MathNode::LargeOp {
                op_type,
                lower,
                upper,
                content,
            } => {
                let op_str = self.large_op_to_latex(op_type);
                let content_str = self.generate_node(content, None);

                let mut result = op_str;

                // Limits render as sub/superscripts: \sum_{i}^{n}.
                if let Some(low) = lower {
                    let low_str = self.generate_node(low, None);
                    result.push_str(&format!("_{{{}}}", low_str));
                }

                if let Some(up) = upper {
                    let up_str = self.generate_node(up, None);
                    result.push_str(&format!("^{{{}}}", up_str));
                }

                format!("{} {}", result, content_str)
            }

            MathNode::Sequence { elements } => elements
                .iter()
                .map(|e| self.generate_node(e, None))
                .collect::<Vec<_>>()
                .join(", "),

            MathNode::Text { content } => {
                format!("\\text{{{}}}", content)
            }

            MathNode::Empty => String::new(),
        }
    }

    /// Convert binary operator to LaTeX
    fn binary_op_to_latex(&self, op: &BinaryOp) -> String {
        match op {
            BinaryOp::Add => "+".to_string(),
            BinaryOp::Subtract => "-".to_string(),
            BinaryOp::Multiply => "\\times".to_string(),
            BinaryOp::Divide => "\\div".to_string(),
            BinaryOp::Power => "^".to_string(),
            BinaryOp::Equal => "=".to_string(),
            BinaryOp::NotEqual => "\\neq".to_string(),
            BinaryOp::Less => "<".to_string(),
            BinaryOp::Greater => ">".to_string(),
            BinaryOp::LessEqual => "\\leq".to_string(),
            BinaryOp::GreaterEqual => "\\geq".to_string(),
            BinaryOp::ApproxEqual => "\\approx".to_string(),
            BinaryOp::Equivalent => "\\equiv".to_string(),
            BinaryOp::Similar => "\\sim".to_string(),
            BinaryOp::Congruent => "\\cong".to_string(),
            BinaryOp::Proportional => "\\propto".to_string(),
            BinaryOp::Custom(s) => s.to_string(),
        }
    }

    /// Convert unary operator to LaTeX
    fn unary_op_to_latex(&self, op: &UnaryOp) -> String {
        match op {
            UnaryOp::Plus => "+".to_string(),
            UnaryOp::Minus => "-".to_string(),
            UnaryOp::Not => "\\neg".to_string(),
            UnaryOp::Custom(s) => s.to_string(),
        }
    }

    /// Convert large operator to LaTeX
    fn large_op_to_latex(&self, op: &LargeOpType) -> String {
        match op {
            LargeOpType::Sum => "\\sum".to_string(),
            LargeOpType::Product => "\\prod".to_string(),
            LargeOpType::Integral => "\\int".to_string(),
            LargeOpType::DoubleIntegral => "\\iint".to_string(),
            LargeOpType::TripleIntegral => "\\iiint".to_string(),
            LargeOpType::ContourIntegral => "\\oint".to_string(),
            LargeOpType::Union => "\\bigcup".to_string(),
            LargeOpType::Intersection => "\\bigcap".to_string(),
            LargeOpType::Coproduct => "\\coprod".to_string(),
            LargeOpType::DirectSum => "\\bigoplus".to_string(),
            LargeOpType::Custom(s) => s.clone(),
        }
    }

    /// Wrap content with brackets
    ///
    /// Honors `auto_size_delimiters`: when set, delimiters are emitted
    /// as \left…\right pairs so LaTeX sizes them to the content.
    fn wrap_with_brackets(&self, content: &str, bracket_type: BracketType) -> String {
        let (left, right) = if self.config.auto_size_delimiters {
            match bracket_type {
                BracketType::Parentheses => ("\\left(", "\\right)"),
                BracketType::Brackets => ("\\left[", "\\right]"),
                BracketType::Braces => ("\\left\\{", "\\right\\}"),
                BracketType::AngleBrackets => ("\\left\\langle", "\\right\\rangle"),
                BracketType::Vertical => ("\\left|", "\\right|"),
                BracketType::DoubleVertical => ("\\left\\|", "\\right\\|"),
                BracketType::Floor => ("\\left\\lfloor", "\\right\\rfloor"),
                BracketType::Ceiling => ("\\left\\lceil", "\\right\\rceil"),
                BracketType::None => ("", ""),
            }
        } else {
            match bracket_type {
                BracketType::Parentheses => ("(", ")"),
                BracketType::Brackets => ("[", "]"),
                BracketType::Braces => ("\\{", "\\}"),
                BracketType::AngleBrackets => ("\\langle", "\\rangle"),
                BracketType::Vertical => ("|", "|"),
                BracketType::DoubleVertical => ("\\|", "\\|"),
                BracketType::Floor => ("\\lfloor", "\\rfloor"),
                BracketType::Ceiling => ("\\lceil", "\\rceil"),
                BracketType::None => ("", ""),
            }
        };

        format!("{}{}{}", left, content, right)
    }

    /// Wrap content in parentheses
    fn wrap_parens(&self, content: &str) -> String {
        self.wrap_with_brackets(content, BracketType::Parentheses)
    }
}
|
||||
|
||||
impl Default for LaTeXGenerator {
    /// Equivalent to [`LaTeXGenerator::new`]: uses the default
    /// [`LaTeXConfig`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Check if a function name is a standard LaTeX function
///
/// These names have dedicated upright macros (`\sin`, `\log`, …);
/// `generate_node` typesets anything else via `\text{…}` instead.
/// Matching is case-sensitive, mirroring LaTeX macro names.
fn is_standard_function(name: &str) -> bool {
    const STANDARD_FUNCTIONS: &[&str] = &[
        "sin", "cos", "tan", "cot", "sec", "csc", "sinh", "cosh", "tanh", "coth", "arcsin",
        "arccos", "arctan", "ln", "log", "exp", "lim", "sup", "inf", "max", "min", "det", "dim",
        "ker", "deg", "gcd", "lcm",
    ];
    STANDARD_FUNCTIONS.contains(&name)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an integer `MathNode::Number`.
    fn num(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    /// Build a `MathNode::Symbol` with no Unicode mapping.
    fn sym(v: &str) -> MathNode {
        MathNode::Symbol {
            value: v.to_string(),
            unicode: None,
        }
    }

    /// Wrap a node in an expression and render it with a default generator.
    fn render(node: MathNode) -> String {
        LaTeXGenerator::new().generate(&MathExpr::new(node, 1.0))
    }

    #[test]
    fn test_simple_number() {
        assert_eq!(render(num("42")), "42");
    }

    #[test]
    fn test_simple_binary() {
        let node = MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(num("1")),
            right: Box::new(num("2")),
        };
        assert_eq!(render(node), "1 + 2");
    }

    #[test]
    fn test_fraction() {
        let node = MathNode::Fraction {
            numerator: Box::new(num("1")),
            denominator: Box::new(num("2")),
        };
        assert_eq!(render(node), "\\frac{1}{2}");
    }

    #[test]
    fn test_square_root() {
        let node = MathNode::Radical {
            index: None,
            radicand: Box::new(num("2")),
        };
        assert_eq!(render(node), "\\sqrt{2}");
    }

    #[test]
    fn test_nth_root() {
        let node = MathNode::Radical {
            index: Some(Box::new(num("3"))),
            radicand: Box::new(num("8")),
        };
        assert_eq!(render(node), "\\sqrt[3]{8}");
    }

    #[test]
    fn test_superscript() {
        let node = MathNode::Script {
            base: Box::new(sym("x")),
            subscript: None,
            superscript: Some(Box::new(num("2")))
        };
        assert_eq!(render(node), "x^{2}");
    }

    #[test]
    fn test_subscript() {
        let node = MathNode::Script {
            base: Box::new(sym("a")),
            subscript: Some(Box::new(num("n"))),
            superscript: None,
        };
        assert_eq!(render(node), "a_{n}");
    }

    #[test]
    fn test_complex_fraction() {
        // (a + b) / (c - d)
        let node = MathNode::Fraction {
            numerator: Box::new(MathNode::Binary {
                op: BinaryOp::Add,
                left: Box::new(sym("a")),
                right: Box::new(sym("b")),
            }),
            denominator: Box::new(MathNode::Binary {
                op: BinaryOp::Subtract,
                left: Box::new(sym("c")),
                right: Box::new(sym("d")),
            }),
        };
        assert_eq!(render(node), "\\frac{a + b}{c - d}");
    }

    #[test]
    fn test_summation() {
        // ∑_{i=1}^{n} i
        let node = MathNode::LargeOp {
            op_type: LargeOpType::Sum,
            lower: Some(Box::new(MathNode::Binary {
                op: BinaryOp::Equal,
                left: Box::new(sym("i")),
                right: Box::new(num("1")),
            })),
            upper: Some(Box::new(sym("n"))),
            content: Box::new(sym("i")),
        };
        assert_eq!(render(node), "\\sum_{i = 1}^{n} i");
    }

    #[test]
    fn test_integral() {
        // ∫ x dx
        let node = MathNode::LargeOp {
            op_type: LargeOpType::Integral,
            lower: None,
            upper: None,
            content: Box::new(MathNode::Sequence {
                elements: vec![sym("x"), sym("dx")],
            }),
        };
        assert_eq!(render(node), "\\int x, dx");
    }

    #[test]
    fn test_matrix() {
        let node = MathNode::Matrix {
            rows: vec![
                vec![num("1"), num("2")],
                vec![num("3"), num("4")],
            ],
            bracket_type: BracketType::Brackets,
        };
        assert_eq!(
            render(node),
            "\\begin{bmatrix} 1 & 2 \\\\ 3 & 4 \\end{bmatrix}"
        );
    }
}
|
||||
408
examples/scipix/src/math/mathml.rs
Normal file
408
examples/scipix/src/math/mathml.rs
Normal file
@@ -0,0 +1,408 @@
|
||||
//! MathML generation from mathematical AST
|
||||
//!
|
||||
//! This module converts mathematical AST nodes to MathML (Mathematical Markup Language)
|
||||
//! XML format for rendering in web browsers and applications.
|
||||
|
||||
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
|
||||
|
||||
/// MathML generator for mathematical expressions.
///
/// Output is wrapped in a `<math>` element carrying the W3C MathML
/// namespace (see `generate`).
pub struct MathMLGenerator {
    /// Use presentation MathML (true) or content MathML (false)
    ///
    /// NOTE(review): the flag is set by `new()`/`content()` but is never
    /// consulted in `generate_node`, so both modes currently emit
    /// presentation MathML — confirm whether content MathML is planned.
    presentation: bool,
}
|
||||
|
||||
impl MathMLGenerator {
|
||||
/// Create a new MathML generator (presentation mode)
|
||||
pub fn new() -> Self {
|
||||
Self { presentation: true }
|
||||
}
|
||||
|
||||
/// Create a content MathML generator
|
||||
pub fn content() -> Self {
|
||||
Self {
|
||||
presentation: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate MathML string from a mathematical expression
|
||||
pub fn generate(&self, expr: &MathExpr) -> String {
|
||||
let content = self.generate_node(&expr.root);
|
||||
format!(
|
||||
r#"<math xmlns="http://www.w3.org/1998/Math/MathML">{}</math>"#,
|
||||
content
|
||||
)
|
||||
}
|
||||
|
||||
    /// Generate MathML for a single node.
    ///
    /// Recursively walks the AST, emitting one presentation-MathML fragment
    /// per node. All user-supplied text passes through `escape_xml` before
    /// being embedded in markup.
    fn generate_node(&self, node: &MathNode) -> String {
        match node {
            // Identifiers (variables, function names) use <mi>.
            MathNode::Symbol { value, .. } => {
                format!("<mi>{}</mi>", escape_xml(value))
            }

            // Numeric literals use <mn>.
            MathNode::Number { value, .. } => {
                format!("<mn>{}</mn>", escape_xml(value))
            }

            // Infix form: operands flank the operator inside one <mrow>.
            MathNode::Binary { op, left, right } => {
                let left_ml = self.generate_node(left);
                let right_ml = self.generate_node(right);
                let op_ml = self.binary_op_to_mathml(op);

                format!("<mrow>{}<mo>{}</mo>{}</mrow>", left_ml, op_ml, right_ml)
            }

            // Prefix form: operator precedes its single operand.
            MathNode::Unary { op, operand } => {
                let op_ml = self.unary_op_to_mathml(op);
                let operand_ml = self.generate_node(operand);

                format!("<mrow><mo>{}</mo>{}</mrow>", op_ml, operand_ml)
            }

            MathNode::Fraction {
                numerator,
                denominator,
            } => {
                let num_ml = self.generate_node(numerator);
                let den_ml = self.generate_node(denominator);

                format!("<mfrac>{}{}</mfrac>", num_ml, den_ml)
            }

            MathNode::Radical { index, radicand } => {
                let rad_ml = self.generate_node(radicand);

                if let Some(idx) = index {
                    let idx_ml = self.generate_node(idx);
                    // <mroot> takes the base first, then the index.
                    format!("<mroot>{}{}</mroot>", rad_ml, idx_ml)
                } else {
                    format!("<msqrt>{}</msqrt>", rad_ml)
                }
            }

            MathNode::Script {
                base,
                subscript,
                superscript,
            } => {
                let base_ml = self.generate_node(base);

                // Choose the script element matching which scripts are set;
                // with neither, the base is returned unwrapped.
                match (subscript, superscript) {
                    (Some(sub), Some(sup)) => {
                        let sub_ml = self.generate_node(sub);
                        let sup_ml = self.generate_node(sup);
                        format!("<msubsup>{}{}{}</msubsup>", base_ml, sub_ml, sup_ml)
                    }
                    (Some(sub), None) => {
                        let sub_ml = self.generate_node(sub);
                        format!("<msub>{}{}</msub>", base_ml, sub_ml)
                    }
                    (None, Some(sup)) => {
                        let sup_ml = self.generate_node(sup);
                        format!("<msup>{}{}</msup>", base_ml, sup_ml)
                    }
                    (None, None) => base_ml,
                }
            }

            MathNode::Function { name, argument } => {
                let name_ml = format!("<mi>{}</mi>", escape_xml(name));
                let arg_ml = self.generate_node(argument);

                // The <mo> carries U+2061, the invisible "function
                // application" operator.
                format!("<mrow>{}<mo>⁡</mo>{}</mrow>", name_ml, arg_ml)
            }

            MathNode::Matrix { rows, bracket_type } => {
                let mut content = String::new();

                // One <mtr> per row, one <mtd> per cell.
                for row in rows {
                    content.push_str("<mtr>");
                    for elem in row {
                        content.push_str("<mtd>");
                        content.push_str(&self.generate_node(elem));
                        content.push_str("</mtd>");
                    }
                    content.push_str("</mtr>");
                }

                let (open, close) = self.bracket_to_mathml(*bracket_type);

                format!(
                    "<mrow><mo>{}</mo><mtable>{}</mtable><mo>{}</mo></mrow>",
                    open, content, close
                )
            }

            MathNode::Group {
                content,
                bracket_type,
            } => {
                let content_ml = self.generate_node(content);
                let (open, close) = self.bracket_to_mathml(*bracket_type);

                // Bracket-less groups need no wrapper element at all.
                if *bracket_type == BracketType::None {
                    content_ml
                } else {
                    format!(
                        "<mrow><mo>{}</mo>{}<mo>{}</mo></mrow>",
                        open, content_ml, close
                    )
                }
            }

            MathNode::LargeOp {
                op_type,
                lower,
                upper,
                content,
            } => {
                let op_ml = self.large_op_to_mathml(op_type);
                let content_ml = self.generate_node(content);

                // Limits attach via munderover / munder / mover depending
                // on which bounds are present.
                match (lower, upper) {
                    (Some(low), Some(up)) => {
                        let low_ml = self.generate_node(low);
                        let up_ml = self.generate_node(up);
                        format!(
                            "<mrow><munderover><mo>{}</mo>{}{}</munderover>{}</mrow>",
                            op_ml, low_ml, up_ml, content_ml
                        )
                    }
                    (Some(low), None) => {
                        let low_ml = self.generate_node(low);
                        format!(
                            "<mrow><munder><mo>{}</mo>{}</munder>{}</mrow>",
                            op_ml, low_ml, content_ml
                        )
                    }
                    (None, Some(up)) => {
                        let up_ml = self.generate_node(up);
                        format!(
                            "<mrow><mover><mo>{}</mo>{}</mover>{}</mrow>",
                            op_ml, up_ml, content_ml
                        )
                    }
                    (None, None) => {
                        format!("<mrow><mo>{}</mo>{}</mrow>", op_ml, content_ml)
                    }
                }
            }

            // Sequences render as comma-separated siblings in one <mrow>.
            MathNode::Sequence { elements } => {
                let mut content = String::new();
                for (i, elem) in elements.iter().enumerate() {
                    if i > 0 {
                        content.push_str("<mo>,</mo>");
                    }
                    content.push_str(&self.generate_node(elem));
                }
                format!("<mrow>{}</mrow>", content)
            }

            MathNode::Text { content } => {
                format!("<mtext>{}</mtext>", escape_xml(content))
            }

            MathNode::Empty => String::new(),
        }
    }
|
||||
|
||||
/// Convert binary operator to MathML
|
||||
fn binary_op_to_mathml(&self, op: &BinaryOp) -> String {
|
||||
match op {
|
||||
BinaryOp::Add => "+".to_string(),
|
||||
BinaryOp::Subtract => "−".to_string(),
|
||||
BinaryOp::Multiply => "×".to_string(),
|
||||
BinaryOp::Divide => "÷".to_string(),
|
||||
BinaryOp::Power => "^".to_string(),
|
||||
BinaryOp::Equal => "=".to_string(),
|
||||
BinaryOp::NotEqual => "≠".to_string(),
|
||||
BinaryOp::Less => "<".to_string(),
|
||||
BinaryOp::Greater => ">".to_string(),
|
||||
BinaryOp::LessEqual => "≤".to_string(),
|
||||
BinaryOp::GreaterEqual => "≥".to_string(),
|
||||
BinaryOp::ApproxEqual => "≈".to_string(),
|
||||
BinaryOp::Equivalent => "≡".to_string(),
|
||||
BinaryOp::Similar => "∼".to_string(),
|
||||
BinaryOp::Congruent => "≅".to_string(),
|
||||
BinaryOp::Proportional => "∝".to_string(),
|
||||
BinaryOp::Custom(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert unary operator to MathML
|
||||
fn unary_op_to_mathml(&self, op: &UnaryOp) -> String {
|
||||
match op {
|
||||
UnaryOp::Plus => "+".to_string(),
|
||||
UnaryOp::Minus => "−".to_string(),
|
||||
UnaryOp::Not => "¬".to_string(),
|
||||
UnaryOp::Custom(s) => s.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Convert large operator to MathML.
    ///
    /// Returns the Unicode glyph placed in the operator's `<mo>` element.
    fn large_op_to_mathml(&self, op: &LargeOpType) -> &'static str {
        match op {
            LargeOpType::Sum => "∑",
            LargeOpType::Product => "∏",
            LargeOpType::Integral => "∫",
            LargeOpType::DoubleIntegral => "∬",
            LargeOpType::TripleIntegral => "∭",
            LargeOpType::ContourIntegral => "∮",
            LargeOpType::Union => "⋃",
            LargeOpType::Intersection => "⋂",
            LargeOpType::Coproduct => "∐",
            LargeOpType::DirectSum => "⊕",
            // A custom operator's text cannot be returned as &'static str
            // from this signature, so it silently degrades to the summation
            // glyph. NOTE(review): confirm this loss is acceptable.
            LargeOpType::Custom(_) => "∑", // Default fallback
        }
    }
|
||||
|
||||
    /// Convert bracket type to MathML delimiters.
    ///
    /// Returns the `(opening, closing)` pair; `BracketType::None` yields
    /// empty strings so callers can emit the content bare.
    fn bracket_to_mathml(&self, bracket_type: BracketType) -> (&'static str, &'static str) {
        match bracket_type {
            BracketType::Parentheses => ("(", ")"),
            BracketType::Brackets => ("[", "]"),
            BracketType::Braces => ("{", "}"),
            BracketType::AngleBrackets => ("⟨", "⟩"),
            BracketType::Vertical => ("|", "|"),
            BracketType::DoubleVertical => ("‖", "‖"),
            BracketType::Floor => ("⌊", "⌋"),
            BracketType::Ceiling => ("⌈", "⌉"),
            BracketType::None => ("", ""),
        }
    }
|
||||
}
|
||||
|
||||
impl Default for MathMLGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Escape XML special characters.
///
/// Replaces the five characters that XML 1.0 defines predefined entities
/// for. `&` must be handled first, otherwise the ampersands introduced by
/// the other replacements would themselves be re-escaped.
///
/// (The previous body had the entity strings collapsed back to the raw
/// characters — `.replace('&', "&")` etc. — making it a no-op / invalid.)
fn escape_xml(s: &str) -> String {
    s.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&apos;")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Integer `MathNode::Number` helper.
    fn num(v: &str) -> MathNode {
        MathNode::Number {
            value: v.to_string(),
            is_decimal: false,
        }
    }

    /// Render a node with a default (presentation) generator.
    fn render(node: MathNode) -> String {
        MathMLGenerator::new().generate(&MathExpr::new(node, 1.0))
    }

    #[test]
    fn test_number() {
        assert!(render(num("42")).contains("<mn>42</mn>"));
    }

    #[test]
    fn test_symbol() {
        let node = MathNode::Symbol {
            value: "x".to_string(),
            unicode: Some('x'),
        };
        assert!(render(node).contains("<mi>x</mi>"));
    }

    #[test]
    fn test_binary_add() {
        let node = MathNode::Binary {
            op: BinaryOp::Add,
            left: Box::new(num("1")),
            right: Box::new(num("2")),
        };
        let result = render(node);
        assert!(result.contains("<mrow>"));
        assert!(result.contains("<mo>+</mo>"));
    }

    #[test]
    fn test_fraction() {
        let node = MathNode::Fraction {
            numerator: Box::new(num("1")),
            denominator: Box::new(num("2")),
        };
        assert!(render(node).contains("<mfrac>"));
    }

    #[test]
    fn test_sqrt() {
        let node = MathNode::Radical {
            index: None,
            radicand: Box::new(num("2")),
        };
        assert!(render(node).contains("<msqrt>"));
    }

    #[test]
    fn test_superscript() {
        let node = MathNode::Script {
            base: Box::new(MathNode::Symbol {
                value: "x".to_string(),
                unicode: Some('x'),
            }),
            subscript: None,
            superscript: Some(Box::new(num("2"))),
        };
        assert!(render(node).contains("<msup>"));
    }

    #[test]
    fn test_xml_escaping() {
        // BUGFIX: the previous expectations were identity strings
        // ("a < b" == "a < b"), which asserted nothing. Escaping must
        // produce the XML entities.
        assert_eq!(escape_xml("a < b"), "a &lt; b");
        assert_eq!(escape_xml("x & y"), "x &amp; y");
    }
}
|
||||
246
examples/scipix/src/math/mod.rs
Normal file
246
examples/scipix/src/math/mod.rs
Normal file
@@ -0,0 +1,246 @@
|
||||
//! Mathematical expression parsing and conversion module
|
||||
//!
|
||||
//! This module provides functionality for parsing, representing, and converting
|
||||
//! mathematical expressions between various formats including LaTeX, MathML, and AsciiMath.
|
||||
//!
|
||||
//! # Modules
|
||||
//!
|
||||
//! - `ast`: Abstract Syntax Tree definitions for mathematical expressions
|
||||
//! - `symbols`: Symbol mappings between Unicode and LaTeX
|
||||
//! - `latex`: LaTeX generation from AST
|
||||
//! - `mathml`: MathML generation from AST
|
||||
//! - `asciimath`: AsciiMath generation from AST
|
||||
//! - `parser`: Expression parsing from various formats
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ## Parsing and converting to LaTeX
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ruvector_scipix::math::{parse_expression, to_latex};
|
||||
//!
|
||||
//! let expr = parse_expression("x^2 + 2x + 1").unwrap();
|
||||
//! let latex = to_latex(&expr);
|
||||
//! println!("LaTeX: {}", latex);
|
||||
//! ```
|
||||
//!
|
||||
//! ## Building an expression manually
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ruvector_scipix::math::ast::{MathExpr, MathNode, BinaryOp};
|
||||
//!
|
||||
//! let expr = MathExpr::new(
|
||||
//! MathNode::Binary {
|
||||
//! op: BinaryOp::Add,
|
||||
//! left: Box::new(MathNode::Number {
|
||||
//! value: "1".to_string(),
|
||||
//! is_decimal: false,
|
||||
//! }),
|
||||
//! right: Box::new(MathNode::Number {
|
||||
//! value: "2".to_string(),
|
||||
//! is_decimal: false,
|
||||
//! }),
|
||||
//! },
|
||||
//! 1.0,
|
||||
//! );
|
||||
//! ```
|
||||
|
||||
pub mod asciimath;
|
||||
pub mod ast;
|
||||
pub mod latex;
|
||||
pub mod mathml;
|
||||
pub mod parser;
|
||||
pub mod symbols;
|
||||
|
||||
// Re-export commonly used types
|
||||
pub use asciimath::AsciiMathGenerator;
|
||||
pub use ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, MathVisitor, UnaryOp};
|
||||
pub use latex::{LaTeXConfig, LaTeXGenerator};
|
||||
pub use mathml::MathMLGenerator;
|
||||
pub use parser::{parse_expression, Parser};
|
||||
pub use symbols::{get_symbol, unicode_to_latex, MathSymbol, SymbolCategory};
|
||||
|
||||
/// Parse a mathematical expression from a string
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `input` - The input string to parse (LaTeX, Unicode, or mixed)
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `Result` containing the parsed `MathExpr` or an error message
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::math::parse_expression;
|
||||
///
|
||||
/// let expr = parse_expression("\\frac{1}{2}").unwrap();
|
||||
/// ```
|
||||
pub fn parse(input: &str) -> Result<MathExpr, String> {
|
||||
parse_expression(input)
|
||||
}
|
||||
|
||||
/// Convert a mathematical expression to LaTeX format
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `expr` - The mathematical expression to convert
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A LaTeX string representation of the expression
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::math::{parse_expression, to_latex};
|
||||
///
|
||||
/// let expr = parse_expression("x^2").unwrap();
|
||||
/// let latex = to_latex(&expr);
|
||||
/// assert!(latex.contains("^"));
|
||||
/// ```
|
||||
pub fn to_latex(expr: &MathExpr) -> String {
|
||||
LaTeXGenerator::new().generate(expr)
|
||||
}
|
||||
|
||||
/// Convert a mathematical expression to LaTeX with custom configuration
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `expr` - The mathematical expression to convert
|
||||
/// * `config` - LaTeX generation configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A LaTeX string representation of the expression
|
||||
pub fn to_latex_with_config(expr: &MathExpr, config: LaTeXConfig) -> String {
|
||||
LaTeXGenerator::with_config(config).generate(expr)
|
||||
}
|
||||
|
||||
/// Convert a mathematical expression to MathML format
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `expr` - The mathematical expression to convert
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A MathML XML string representation of the expression
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::math::{parse_expression, to_mathml};
|
||||
///
|
||||
/// let expr = parse_expression("x^2").unwrap();
|
||||
/// let mathml = to_mathml(&expr);
|
||||
/// assert!(mathml.contains("<msup>"));
|
||||
/// ```
|
||||
pub fn to_mathml(expr: &MathExpr) -> String {
|
||||
MathMLGenerator::new().generate(expr)
|
||||
}
|
||||
|
||||
/// Convert a mathematical expression to AsciiMath format
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `expr` - The mathematical expression to convert
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An AsciiMath string representation of the expression
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::math::{parse_expression, to_asciimath};
|
||||
///
|
||||
/// let expr = parse_expression("x^2").unwrap();
|
||||
/// let asciimath = to_asciimath(&expr);
|
||||
/// ```
|
||||
pub fn to_asciimath(expr: &MathExpr) -> String {
|
||||
AsciiMathGenerator::new().generate(expr)
|
||||
}
|
||||
|
||||
/// Convert a mathematical expression to ASCII-only AsciiMath format
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `expr` - The mathematical expression to convert
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An ASCII-only AsciiMath string representation of the expression
|
||||
pub fn to_asciimath_ascii_only(expr: &MathExpr) -> String {
|
||||
AsciiMathGenerator::ascii_only().generate(expr)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: a parsed expression round-trips to LaTeX.
    #[test]
    fn test_parse_and_convert() {
        let expr = parse("1 + 2").unwrap();
        let latex = to_latex(&expr);
        assert!(latex.contains("+"));
    }

    // A LaTeX fraction converts to all three output formats.
    #[test]
    fn test_fraction_conversion() {
        let expr = parse("\\frac{1}{2}").unwrap();

        let latex = to_latex(&expr);
        assert!(latex.contains("\\frac"));

        let mathml = to_mathml(&expr);
        assert!(mathml.contains("<mfrac>"));

        let asciimath = to_asciimath(&expr);
        assert!(asciimath.contains("/"));
    }

    // A square root converts to all three output formats.
    #[test]
    fn test_sqrt_conversion() {
        let expr = parse("\\sqrt{2}").unwrap();

        let latex = to_latex(&expr);
        assert!(latex.contains("\\sqrt"));

        let mathml = to_mathml(&expr);
        assert!(mathml.contains("<msqrt>"));

        let asciimath = to_asciimath(&expr);
        assert!(asciimath.contains("sqrt"));
    }

    #[test]
    fn test_complex_expression() {
        // Quadratic formula: (-b ± √(b² - 4ac)) / 2a
        let expr = parse("\\frac{-b + \\sqrt{b^2 - 4*a*c}}{2*a}").unwrap();

        let latex = to_latex(&expr);
        assert!(latex.contains("\\frac"));
        assert!(latex.contains("\\sqrt"));

        let mathml = to_mathml(&expr);
        assert!(mathml.contains("<mfrac>"));
        assert!(mathml.contains("<msqrt>"));
    }

    // Unicode → LaTeX symbol-table lookups (note: names come back without
    // the leading backslash).
    #[test]
    fn test_symbol_lookup() {
        assert!(unicode_to_latex('α').is_some());
        assert_eq!(unicode_to_latex('α'), Some("alpha"));
        assert_eq!(unicode_to_latex('π'), Some("pi"));
        assert_eq!(unicode_to_latex('∑'), Some("sum"))
    }

    // Full symbol metadata lookup, including its category.
    #[test]
    fn test_get_symbol() {
        let sym = get_symbol('α').unwrap();
        assert_eq!(sym.latex, "alpha");
        assert_eq!(sym.category, SymbolCategory::Greek);
    }
}
|
||||
529
examples/scipix/src/math/parser.rs
Normal file
529
examples/scipix/src/math/parser.rs
Normal file
@@ -0,0 +1,529 @@
|
||||
//! Mathematical expression parser
|
||||
//!
|
||||
//! This module parses mathematical expressions from various formats
|
||||
//! including LaTeX, Unicode text, and symbolic notation.
|
||||
|
||||
use crate::math::ast::{BinaryOp, BracketType, LargeOpType, MathExpr, MathNode, UnaryOp};
|
||||
use crate::math::symbols::get_symbol;
|
||||
use nom::{
|
||||
branch::alt,
|
||||
bytes::complete::{tag, take_while, take_while1},
|
||||
character::complete::{alpha1, char, digit1, multispace0},
|
||||
combinator::{map, opt, recognize},
|
||||
multi::{many0, separated_list0},
|
||||
sequence::{delimited, pair, preceded, tuple},
|
||||
IResult,
|
||||
};
|
||||
|
||||
/// Parser for mathematical expressions.
///
/// Implements a small recursive-descent grammar over nom combinators;
/// `parse` is the public entry point.
pub struct Parser {
    /// Confidence score for parsed expression
    ///
    /// NOTE(review): set to 1.0 in `new()` and never modified in the
    /// visible parsing methods — confirm whether partial/fallback parses
    /// are meant to lower it.
    confidence: f32,
}
|
||||
|
||||
impl Parser {
|
||||
    /// Create a new parser.
    ///
    /// Confidence starts at the maximum (1.0).
    pub fn new() -> Self {
        Self { confidence: 1.0 }
    }
|
||||
|
||||
/// Parse a mathematical expression from string
|
||||
pub fn parse(&mut self, input: &str) -> Result<MathExpr, String> {
|
||||
match self.parse_expression(input) {
|
||||
Ok((_, node)) => Ok(MathExpr::new(node, self.confidence)),
|
||||
Err(e) => Err(format!("Parse error: {:?}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Parse top-level expression.
    ///
    /// Entry point of the grammar: relational operators bind loosest, so
    /// parsing starts at that level.
    fn parse_expression<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        self.parse_relational(input)
    }
|
||||
|
||||
    /// Parse relational operators (=, <, >, etc.)
    ///
    /// Parses one additive expression, then optionally a single relational
    /// operator and a second additive expression; relations are therefore
    /// not chained (`a = b = c` stops after `a = b`).
    fn parse_relational<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_additive(input)?;
        let (input, op_right) = opt(pair(
            delimited(
                multispace0,
                // Multi-character operators ("==", "<=", ">=") are listed
                // before their single-character prefixes so `alt` does not
                // match the prefix first.
                alt((
                    map(tag("=="), |_| BinaryOp::Equal),
                    map(tag("="), |_| BinaryOp::Equal),
                    map(tag("!="), |_| BinaryOp::NotEqual),
                    map(tag("≠"), |_| BinaryOp::NotEqual),
                    map(tag("<="), |_| BinaryOp::LessEqual),
                    map(tag("≤"), |_| BinaryOp::LessEqual),
                    map(tag(">="), |_| BinaryOp::GreaterEqual),
                    map(tag("≥"), |_| BinaryOp::GreaterEqual),
                    map(tag("<"), |_| BinaryOp::Less),
                    map(tag(">"), |_| BinaryOp::Greater),
                    map(tag("≈"), |_| BinaryOp::ApproxEqual),
                    map(tag("≡"), |_| BinaryOp::Equivalent),
                    map(tag("∼"), |_| BinaryOp::Similar),
                )),
                multispace0,
            ),
            |i| self.parse_additive(i),
        ))(input)?;

        Ok((
            input,
            if let Some((op, right)) = op_right {
                MathNode::Binary {
                    op,
                    left: Box::new(left),
                    right: Box::new(right),
                }
            } else {
                // No relational operator: pass the left side through.
                left
            },
        ))
    }
|
||||
|
||||
    /// Parse additive operators (+, -)
    ///
    /// Left-associative: the trailing `(op, operand)` pairs are folded onto
    /// the first operand, so `a - b + c` becomes `((a - b) + c)`.
    fn parse_additive<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_multiplicative(input)?;
        let (input, ops) = many0(pair(
            delimited(
                multispace0,
                alt((
                    map(char('+'), |_| BinaryOp::Add),
                    map(char('-'), |_| BinaryOp::Subtract),
                )),
                multispace0,
            ),
            |i| self.parse_multiplicative(i),
        ))(input)?;

        Ok((
            input,
            ops.into_iter()
                .fold(left, |acc, (op, right)| MathNode::Binary {
                    op,
                    left: Box::new(acc),
                    right: Box::new(right),
                }),
        ))
    }
|
||||
|
||||
    /// Parse multiplicative operators (*, /, ×, ÷)
    ///
    /// Left-associative, like `parse_additive`. Accepts ASCII, Unicode, and
    /// LaTeX spellings (`\times`, `\div`, `\cdot`) of the same operators.
    fn parse_multiplicative<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, left) = self.parse_power(input)?;
        let (input, ops) = many0(pair(
            delimited(
                multispace0,
                alt((
                    map(char('*'), |_| BinaryOp::Multiply),
                    map(char('/'), |_| BinaryOp::Divide),
                    map(char('×'), |_| BinaryOp::Multiply),
                    map(char('÷'), |_| BinaryOp::Divide),
                    map(tag("\\times"), |_| BinaryOp::Multiply),
                    map(tag("\\div"), |_| BinaryOp::Divide),
                    map(tag("\\cdot"), |_| BinaryOp::Multiply),
                )),
                multispace0,
            ),
            |i| self.parse_power(i),
        ))(input)?;

        Ok((
            input,
            ops.into_iter()
                .fold(left, |acc, (op, right)| MathNode::Binary {
                    op,
                    left: Box::new(acc),
                    right: Box::new(right),
                }),
        ))
    }
|
||||
|
||||
    /// Parse power operator (^)
    ///
    /// The exponent is a single unary expression, so powers are not chained
    /// at this level.
    ///
    /// NOTE(review): `parse_unary` → `parse_script` already consumes `^` as
    /// a superscript, so the branch below fires only when `parse_script`
    /// left the `^` unconsumed (e.g. whitespace before it) — confirm the
    /// intended precedence between `Script` and `Binary::Power`.
    fn parse_power<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, base) = self.parse_unary(input)?;
        let (input, exp) = opt(preceded(
            delimited(multispace0, char('^'), multispace0),
            |i| self.parse_unary(i),
        ))(input)?;

        Ok((
            input,
            if let Some(exponent) = exp {
                MathNode::Binary {
                    op: BinaryOp::Power,
                    left: Box::new(base),
                    right: Box::new(exponent),
                }
            } else {
                base
            },
        ))
    }
|
||||
|
||||
    /// Parse unary operators (+, -)
    ///
    /// First tries a sign followed by a script expression; if no sign is
    /// present, falls through to `parse_script` directly.
    fn parse_unary<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        alt((
            map(
                pair(
                    delimited(
                        multispace0,
                        alt((
                            map(char('+'), |_| UnaryOp::Plus),
                            map(char('-'), |_| UnaryOp::Minus),
                        )),
                        multispace0,
                    ),
                    |i| self.parse_script(i),
                ),
                |(op, operand)| MathNode::Unary {
                    op,
                    operand: Box::new(operand),
                },
            ),
            |i| self.parse_script(i),
        ))(input)
    }
|
||||
|
||||
    /// Parse subscript/superscript.
    ///
    /// Accepts an optional `_` subscript followed by an optional `^`
    /// superscript — in that order only (`x^2_i` would stop after `x^2`).
    /// With neither present, the base node is returned unwrapped.
    fn parse_script<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, base) = self.parse_primary(input)?;
        let (input, sub) = opt(preceded(char('_'), |i| self.parse_script_content(i)))(input)?;
        let (input, sup) = opt(preceded(char('^'), |i| self.parse_script_content(i)))(input)?;

        Ok((
            input,
            if sub.is_some() || sup.is_some() {
                MathNode::Script {
                    base: Box::new(base),
                    subscript: sub.map(Box::new),
                    superscript: sup.map(Box::new),
                }
            } else {
                base
            },
        ))
    }
|
||||
|
||||
    /// Parse script content (single char or braced expression)
    ///
    /// Accepts `{expr}` (full expression), a run of letters (one Symbol), or
    /// a run of digits (an integer Number).
    fn parse_script_content<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        alt((
            delimited(char('{'), |i| self.parse_expression(i), char('}')),
            map(recognize(alpha1), |s: &str| MathNode::Symbol {
                value: s.to_string(),
                // `unicode` carries only the first character of the run.
                unicode: s.chars().next(),
            }),
            map(digit1, |s: &str| MathNode::Number {
                value: s.to_string(),
                is_decimal: false,
            }),
        ))(input)
    }
|
||||
|
||||
    /// Parse primary expressions (atoms)
    ///
    /// Alternatives are tried in order with backtracking; whitespace on both
    /// sides is consumed.
    // NOTE(review): parse_function is tried before parse_greek, so `\alpha x`
    // parses as Function("alpha", x) rather than a Greek symbol followed by a
    // variable — confirm this ordering is intended.
    fn parse_primary<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        delimited(
            multispace0,
            alt((
                |i| self.parse_function(i),
                |i| self.parse_fraction(i),
                |i| self.parse_radical(i),
                |i| self.parse_large_op(i),
                |i| self.parse_greek(i),
                |i| self.parse_number(i),
                |i| self.parse_symbol(i),
                |i| self.parse_grouped(i),
            )),
            multispace0,
        )(input)
    }
|
||||
|
||||
    /// Parse fraction (\frac{a}{b})
    ///
    /// Both the numerator and the denominator are braced full expressions.
    fn parse_fraction<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = tag("\\frac")(input)?;
        let (input, num) = delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;
        let (input, den) = delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;

        Ok((
            input,
            MathNode::Fraction {
                numerator: Box::new(num),
                denominator: Box::new(den),
            },
        ))
    }
|
||||
|
||||
    /// Parse radical (\sqrt[n]{x})
    ///
    /// The bracketed index `[n]` is optional (plain square root when absent);
    /// the radicand is a braced full expression.
    fn parse_radical<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = tag("\\sqrt")(input)?;
        let (input, index) = opt(delimited(
            char('['),
            |i| self.parse_expression(i),
            char(']'),
        ))(input)?;
        let (input, radicand) =
            delimited(char('{'), |i| self.parse_expression(i), char('}'))(input)?;

        Ok((
            input,
            MathNode::Radical {
                index: index.map(Box::new),
                radicand: Box::new(radicand),
            },
        ))
    }
|
||||
|
||||
    /// Parse large operators (sum, integral, etc.)
    ///
    /// Accepts LaTeX commands and their Unicode glyph equivalents, then
    /// optional `_lower` / `^upper` bounds (each either a braced expression
    /// or a single primary), then a single primary as the operand.
    fn parse_large_op<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, op_type) = alt((
            map(tag("\\sum"), |_| LargeOpType::Sum),
            map(tag("\\prod"), |_| LargeOpType::Product),
            map(tag("\\int"), |_| LargeOpType::Integral),
            map(tag("\\iint"), |_| LargeOpType::DoubleIntegral),
            map(tag("\\iiint"), |_| LargeOpType::TripleIntegral),
            map(tag("\\oint"), |_| LargeOpType::ContourIntegral),
            map(tag("∑"), |_| LargeOpType::Sum),
            map(tag("∏"), |_| LargeOpType::Product),
            map(tag("∫"), |_| LargeOpType::Integral),
        ))(input)?;

        // Optional lower bound after '_'.
        let (input, lower) = opt(preceded(
            char('_'),
            alt((
                delimited(char('{'), |i| self.parse_expression(i), char('}')),
                |i| self.parse_primary(i),
            )),
        ))(input)?;

        // Optional upper bound after '^'.
        let (input, upper) = opt(preceded(
            char('^'),
            alt((
                delimited(char('{'), |i| self.parse_expression(i), char('}')),
                |i| self.parse_primary(i),
            )),
        ))(input)?;

        // Operand is one primary, not a full expression, so `\sum i + 1`
        // binds only `i` to the operator.
        let (input, content) = self.parse_primary(input)?;

        Ok((
            input,
            MathNode::LargeOp {
                op_type,
                lower: lower.map(Box::new),
                upper: upper.map(Box::new),
                content: Box::new(content),
            },
        ))
    }
|
||||
|
||||
    /// Parse function (sin, cos, etc.)
    ///
    /// Matches any `\name` followed by one primary as the argument. The name
    /// is NOT checked against a known-function list, so this also matches
    /// commands such as `\alpha` whenever a primary follows (see the ordering
    /// note on parse_primary).
    fn parse_function<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = char('\\')(input)?;
        let (input, name) = alpha1(input)?;
        let (input, _) = multispace0(input)?;
        let (input, arg) = self.parse_primary(input)?;

        Ok((
            input,
            MathNode::Function {
                name: name.to_string(),
                argument: Box::new(arg),
            },
        ))
    }
|
||||
|
||||
    /// Parse Greek letter
    ///
    /// Accepts any `\name` as a Symbol; only the ten names listed below are
    /// mapped to a Unicode code point, all others keep `unicode: None`.
    fn parse_greek<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, _) = char('\\')(input)?;
        let (input, name) = alpha1(input)?;

        // Convert LaTeX name to Unicode if possible
        let unicode = match name {
            "alpha" => Some('α'),
            "beta" => Some('β'),
            "gamma" => Some('γ'),
            "delta" => Some('δ'),
            "epsilon" => Some('ε'),
            "pi" => Some('π'),
            "theta" => Some('θ'),
            "lambda" => Some('λ'),
            "mu" => Some('μ'),
            "sigma" => Some('σ'),
            _ => None,
        };

        Ok((
            input,
            MathNode::Symbol {
                value: name.to_string(),
                unicode,
            },
        ))
    }
|
||||
|
||||
    /// Parse number
    ///
    /// Accepts `123` or `123.456`; a bare leading or trailing dot is rejected
    /// because both sides of '.' require at least one digit.
    fn parse_number<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        let (input, num_str) = recognize(pair(digit1, opt(pair(char('.'), digit1))))(input)?;
        // The raw text is kept; decimals are flagged by the presence of '.'.
        let is_decimal = num_str.contains('.');

        Ok((
            input,
            MathNode::Number {
                value: num_str.to_string(),
                is_decimal,
            },
        ))
    }
|
||||
|
||||
    /// Parse symbol (variable)
    ///
    /// Consumes a whole alphabetic run as ONE symbol, so `ac` becomes a
    /// single Symbol("ac") rather than two variables; `unicode` carries only
    /// the first character.
    fn parse_symbol<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        map(take_while1(|c: char| c.is_alphabetic()), |s: &str| {
            let c = s.chars().next();
            MathNode::Symbol {
                value: s.to_string(),
                unicode: c,
            }
        })(input)
    }
|
||||
|
||||
    /// Parse grouped expression (parentheses)
    ///
    /// `(expr)` — the parentheses themselves are dropped; grouping survives
    /// only through the resulting tree structure.
    fn parse_grouped<'a>(&self, input: &'a str) -> IResult<&'a str, MathNode> {
        delimited(char('('), |i| self.parse_expression(i), char(')'))(input)
    }
|
||||
}
|
||||
|
||||
impl Default for Parser {
    /// Equivalent to [`Parser::new`].
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Parse a mathematical expression from string
///
/// Convenience wrapper: builds a fresh [`Parser`] and delegates to its
/// `parse` method, returning the parsed expression or an error string.
pub fn parse_expression(input: &str) -> Result<MathExpr, String> {
    Parser::new().parse(input)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Each test parses a small input and pattern-matches the root node kind.

    #[test]
    fn test_parse_number() {
        let expr = parse_expression("42").unwrap();
        match expr.root {
            MathNode::Number { value, .. } => assert_eq!(value, "42"),
            _ => panic!("Expected Number node"),
        }
    }

    #[test]
    fn test_parse_addition() {
        let expr = parse_expression("1 + 2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Add),
            _ => panic!("Expected Binary node"),
        }
    }

    #[test]
    fn test_parse_multiplication() {
        let expr = parse_expression("3 * 4").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Multiply),
            _ => panic!("Expected Binary node"),
        }
    }

    #[test]
    fn test_parse_precedence() {
        let expr = parse_expression("1 + 2 * 3").unwrap();
        // Should parse as 1 + (2 * 3)
        match expr.root {
            MathNode::Binary {
                op: BinaryOp::Add,
                left,
                right,
            } => {
                assert!(matches!(*left, MathNode::Number { .. }));
                assert!(matches!(
                    *right,
                    MathNode::Binary {
                        op: BinaryOp::Multiply,
                        ..
                    }
                ));
            }
            _ => panic!("Expected Add with Multiply on right"),
        }
    }

    #[test]
    fn test_parse_power() {
        let expr = parse_expression("x^2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Power),
            _ => panic!("Expected Binary node with power"),
        }
    }

    #[test]
    fn test_parse_fraction() {
        let expr = parse_expression("\\frac{1}{2}").unwrap();
        match expr.root {
            MathNode::Fraction { .. } => {}
            _ => panic!("Expected Fraction node"),
        }
    }

    #[test]
    fn test_parse_sqrt() {
        // Plain square root: no bracketed index.
        let expr = parse_expression("\\sqrt{2}").unwrap();
        match expr.root {
            MathNode::Radical { index, .. } => assert!(index.is_none()),
            _ => panic!("Expected Radical node"),
        }
    }

    #[test]
    fn test_parse_nth_root() {
        let expr = parse_expression("\\sqrt[3]{8}").unwrap();
        match expr.root {
            MathNode::Radical { index, .. } => assert!(index.is_some()),
            _ => panic!("Expected Radical node with index"),
        }
    }

    #[test]
    fn test_parse_subscript() {
        let expr = parse_expression("a_n").unwrap();
        match expr.root {
            MathNode::Script { subscript, .. } => assert!(subscript.is_some()),
            _ => panic!("Expected Script node"),
        }
    }

    #[test]
    fn test_parse_superscript() {
        // NOTE(review): same input as test_parse_power ("x^2"); presumably
        // this was meant to exercise Script superscripts — confirm.
        let expr = parse_expression("x^2").unwrap();
        match expr.root {
            MathNode::Binary { op, .. } => assert_eq!(op, BinaryOp::Power),
            _ => panic!("Expected power operation"),
        }
    }

    #[test]
    fn test_parse_sum() {
        let expr = parse_expression("\\sum_{i=1}^{n} i").unwrap();
        match expr.root {
            MathNode::LargeOp { op_type, .. } => assert_eq!(op_type, LargeOpType::Sum),
            _ => panic!("Expected LargeOp node"),
        }
    }

    #[test]
    fn test_parse_complex() {
        // Quadratic-formula shape round-trips through the parser.
        let expr = parse_expression("\\frac{-b + \\sqrt{b^2 - 4ac}}{2a}").unwrap();
        match expr.root {
            MathNode::Fraction { .. } => {}
            _ => panic!("Expected Fraction node"),
        }
    }
}
|
||||
1247
examples/scipix/src/math/symbols.rs
Normal file
1247
examples/scipix/src/math/symbols.rs
Normal file
File diff suppressed because it is too large
Load Diff
384
examples/scipix/src/ocr/confidence.rs
Normal file
384
examples/scipix/src/ocr/confidence.rs
Normal file
@@ -0,0 +1,384 @@
|
||||
//! Confidence Scoring Module
|
||||
//!
|
||||
//! This module provides confidence scoring and calibration for OCR results.
|
||||
//! It includes per-character confidence calculation and aggregation methods.
|
||||
|
||||
use super::Result;
|
||||
use std::collections::HashMap;
|
||||
use tracing::debug;
|
||||
|
||||
/// Calculate confidence score for a single character prediction
///
/// # Arguments
/// * `logits` - Raw logits from the model for this character position
///
/// # Returns
/// Confidence score between 0.0 and 1.0: the softmax probability of the
/// most likely class, or 0.0 for empty (or NaN-contaminated) input.
pub fn calculate_confidence(logits: &[f32]) -> f32 {
    if logits.is_empty() {
        return 0.0;
    }

    // Numerically stable softmax: shift by the max logit so exp() cannot overflow.
    let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
    let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();

    // NaN in the input propagates into exp_sum; the previous fold-based
    // maximum silently returned 0.0 in that case, so keep that behavior.
    if exp_sum.is_nan() {
        return 0.0;
    }

    // The maximum softmax probability belongs to the max logit, whose shifted
    // exponential is exp(0) == 1, so the max probability is simply 1/exp_sum.
    // This avoids the second full pass over the logits that the original
    // implementation performed.
    (1.0 / exp_sum).clamp(0.0, 1.0)
}
|
||||
|
||||
/// Aggregate multiple confidence scores into a single score
///
/// # Arguments
/// * `confidences` - Individual confidence scores, expected in [0.0, 1.0]
///
/// # Returns
/// Aggregated confidence score using the geometric mean (more conservative
/// than the arithmetic mean); 0.0 for empty input.
pub fn aggregate_confidence(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    // Geometric mean computed in log space: a direct f32 product underflows
    // to 0.0 for long inputs (a few thousand scores around 0.9 are enough),
    // which would report zero confidence regardless of the data.
    // ln(0) = -inf still yields 0.0 overall, matching the old zero-product case.
    let n = confidences.len() as f32;
    let mean_log: f32 = confidences.iter().map(|&c| c.ln()).sum::<f32>() / n;
    mean_log.exp().clamp(0.0, 1.0)
}
|
||||
|
||||
/// Alternative aggregation using arithmetic mean
///
/// Returns the average of the scores clamped into [0.0, 1.0], or 0.0 when
/// the slice is empty.
pub fn aggregate_confidence_mean(confidences: &[f32]) -> f32 {
    match confidences.len() {
        0 => 0.0,
        count => {
            let total: f32 = confidences.iter().copied().sum();
            (total / count as f32).clamp(0.0, 1.0)
        }
    }
}
|
||||
|
||||
/// Alternative aggregation using minimum (most conservative)
///
/// Returns the smallest score clamped into [0.0, 1.0], or 0.0 for empty
/// input — consistent with the other aggregators. (The previous fold-based
/// version returned 1.0 for an empty slice, i.e. full confidence in the
/// absence of any data.)
pub fn aggregate_confidence_min(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    confidences
        .iter()
        .fold(1.0f32, |a, &b| a.min(b))
        .clamp(0.0, 1.0)
}
|
||||
|
||||
/// Alternative aggregation using harmonic mean
///
/// Returns the harmonic mean clamped into [0.0, 1.0], or 0.0 for empty
/// input. Scores are floored at 0.001 so a zero cannot divide by zero.
pub fn aggregate_confidence_harmonic(confidences: &[f32]) -> f32 {
    if confidences.is_empty() {
        return 0.0;
    }

    // Accumulate the reciprocals, flooring each score to avoid 1/0.
    let mut reciprocal_sum = 0.0f32;
    for &score in confidences {
        reciprocal_sum += 1.0 / score.max(0.001);
    }

    let count = confidences.len() as f32;
    (count / reciprocal_sum).clamp(0.0, 1.0)
}
|
||||
|
||||
/// Confidence calibrator using isotonic regression
///
/// This calibrator learns a mapping from raw confidence scores to calibrated
/// probabilities using historical data. Raw scores are binned into integer
/// percent buckets; see `train` / `calibrate`.
pub struct ConfidenceCalibrator {
    /// Calibration mapping: raw_score -> calibrated_score
    calibration_map: HashMap<u8, f32>, // Use u8 for binned scores (0-100)
    /// Whether the calibrator has been trained
    is_trained: bool,
}
|
||||
|
||||
impl ConfidenceCalibrator {
    /// Create a new, untrained calibrator
    pub fn new() -> Self {
        Self {
            calibration_map: HashMap::new(),
            is_trained: false,
        }
    }

    /// Train the calibrator on labeled data
    ///
    /// # Arguments
    /// * `predictions` - Raw confidence scores from the model
    /// * `ground_truth` - Binary labels (1.0 if correct, 0.0 if incorrect)
    ///
    /// # Errors
    /// Returns `OcrError::InvalidConfig` when the slices differ in length or
    /// are empty.
    pub fn train(&mut self, predictions: &[f32], ground_truth: &[f32]) -> Result<()> {
        debug!(
            "Training confidence calibrator on {} samples",
            predictions.len()
        );

        if predictions.len() != ground_truth.len() {
            return Err(super::OcrError::InvalidConfig(
                "Predictions and ground truth must have same length".to_string(),
            ));
        }

        if predictions.is_empty() {
            return Err(super::OcrError::InvalidConfig(
                "Cannot train on empty data".to_string(),
            ));
        }

        // Bin the scores (0.0-1.0 -> 0-100)
        let mut bins: HashMap<u8, Vec<f32>> = HashMap::new();

        for (&pred, &truth) in predictions.iter().zip(ground_truth.iter()) {
            let bin = (pred * 100.0).clamp(0.0, 100.0) as u8;
            bins.entry(bin).or_insert_with(Vec::new).push(truth);
        }

        // Calculate mean accuracy for each bin: the calibrated value of a bin
        // is the empirical fraction of correct predictions that fell into it.
        self.calibration_map.clear();
        for (bin, truths) in bins {
            let mean_accuracy = truths.iter().sum::<f32>() / truths.len() as f32;
            self.calibration_map.insert(bin, mean_accuracy);
        }

        // Perform isotonic regression (simplified version)
        self.enforce_monotonicity();

        self.is_trained = true;
        debug!(
            "Calibrator trained with {} bins",
            self.calibration_map.len()
        );

        Ok(())
    }

    /// Enforce monotonicity constraint (isotonic regression)
    ///
    /// Sweeps the bins in ascending order and raises each calibrated value
    /// to at least the previous one, so `calibrate` is non-decreasing in the
    /// raw score.
    fn enforce_monotonicity(&mut self) {
        let mut sorted_bins: Vec<_> = self.calibration_map.iter().collect();
        sorted_bins.sort_by_key(|(bin, _)| *bin);

        // Simple isotonic regression: ensure calibrated scores are non-decreasing
        let mut adjusted = HashMap::new();
        let mut prev_value = 0.0;

        for (&bin, &value) in sorted_bins {
            let adjusted_value = value.max(prev_value);
            adjusted.insert(bin, adjusted_value);
            prev_value = adjusted_value;
        }

        self.calibration_map = adjusted;
    }

    /// Calibrate a raw confidence score
    ///
    /// Untrained: returns the clamped raw score. Trained: looks up the
    /// score's percent bin, interpolating between neighbouring trained bins
    /// when the exact bin was never seen.
    pub fn calibrate(&self, raw_score: f32) -> f32 {
        if !self.is_trained {
            // If not trained, return raw score
            return raw_score.clamp(0.0, 1.0);
        }

        let bin = (raw_score * 100.0).clamp(0.0, 100.0) as u8;

        // Look up calibrated score, or interpolate
        if let Some(&calibrated) = self.calibration_map.get(&bin) {
            return calibrated;
        }

        // Interpolate between nearest bins
        self.interpolate(bin)
    }

    /// Interpolate calibrated score for a bin without direct mapping
    ///
    /// Finds the nearest trained bins below and above `target_bin` and
    /// blends them linearly; falls back to one side (or the raw bin value)
    /// when a neighbour is missing.
    fn interpolate(&self, target_bin: u8) -> f32 {
        let mut lower = None;
        let mut upper = None;

        for &bin in self.calibration_map.keys() {
            if bin < target_bin {
                // Closest trained bin strictly below the target.
                lower = Some(lower.map_or(bin, |l: u8| l.max(bin)));
            } else if bin > target_bin {
                // Closest trained bin strictly above the target.
                upper = Some(upper.map_or(bin, |u: u8| u.min(bin)));
            }
        }

        match (lower, upper) {
            (Some(l), Some(u)) => {
                // l < target_bin < u, so u - l > 0 and the division is safe.
                let l_val = self.calibration_map[&l];
                let u_val = self.calibration_map[&u];
                let alpha = (target_bin - l) as f32 / (u - l) as f32;
                l_val + alpha * (u_val - l_val)
            }
            (Some(l), None) => self.calibration_map[&l],
            (None, Some(u)) => self.calibration_map[&u],
            (None, None) => target_bin as f32 / 100.0, // Fallback
        }
    }

    /// Check if the calibrator is trained
    pub fn is_trained(&self) -> bool {
        self.is_trained
    }

    /// Reset the calibrator
    pub fn reset(&mut self) {
        self.calibration_map.clear();
        self.is_trained = false;
    }
}
|
||||
|
||||
impl Default for ConfidenceCalibrator {
    /// Equivalent to [`ConfidenceCalibrator::new`] (untrained, empty map).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Calculate Expected Calibration Error (ECE)
///
/// Measures the gap between predicted confidence and actual accuracy:
/// predictions are split into `n_bins` equal-width confidence buckets, and
/// each bucket contributes |mean confidence - mean accuracy| weighted by its
/// share of the samples. Returns 0.0 for empty or mismatched-length input.
pub fn calculate_ece(predictions: &[f32], ground_truth: &[f32], n_bins: usize) -> f32 {
    if predictions.is_empty() || predictions.len() != ground_truth.len() {
        return 0.0;
    }

    // Distribute (prediction, truth) pairs into equal-width buckets;
    // a prediction of exactly 1.0 lands in the last bucket.
    let mut buckets: Vec<Vec<(f32, f32)>> = vec![Vec::new(); n_bins];
    for (&pred, &truth) in predictions.iter().zip(ground_truth) {
        let slot = ((pred * n_bins as f32) as usize).min(n_bins - 1);
        buckets[slot].push((pred, truth));
    }

    let total = predictions.len() as f32;
    buckets
        .iter()
        .filter(|bucket| !bucket.is_empty())
        .map(|bucket| {
            let size = bucket.len() as f32;
            let avg_confidence = bucket.iter().map(|(p, _)| p).sum::<f32>() / size;
            let avg_accuracy = bucket.iter().map(|(_, t)| t).sum::<f32>() / size;
            (size / total) * (avg_confidence - avg_accuracy).abs()
        })
        .sum()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calculate_confidence() {
        // Logit 5.0 dominates, so the top softmax probability exceeds 0.5.
        let logits = vec![1.0, 5.0, 2.0, 1.0];
        let conf = calculate_confidence(&logits);
        assert!(conf > 0.5);
        assert!(conf <= 1.0);
    }

    #[test]
    fn test_calculate_confidence_empty() {
        let logits: Vec<f32> = vec![];
        let conf = calculate_confidence(&logits);
        assert_eq!(conf, 0.0);
    }

    #[test]
    fn test_aggregate_confidence() {
        let confidences = vec![0.9, 0.8, 0.95, 0.85];
        let agg = aggregate_confidence(&confidences);
        assert!(agg > 0.0 && agg <= 1.0);
        assert!(agg < 0.9); // Geometric mean should be less than max
    }

    #[test]
    fn test_aggregate_confidence_mean() {
        let confidences = vec![0.8, 0.9, 0.7];
        let mean = aggregate_confidence_mean(&confidences);
        assert_eq!(mean, 0.8); // (0.8 + 0.9 + 0.7) / 3
    }

    #[test]
    fn test_aggregate_confidence_min() {
        let confidences = vec![0.9, 0.7, 0.95];
        let min = aggregate_confidence_min(&confidences);
        assert_eq!(min, 0.7);
    }

    #[test]
    fn test_aggregate_confidence_harmonic() {
        // Harmonic mean of equal values is the value itself.
        let confidences = vec![0.5, 0.5];
        let harmonic = aggregate_confidence_harmonic(&confidences);
        assert_eq!(harmonic, 0.5);
    }

    #[test]
    fn test_calibrator_training() {
        let mut calibrator = ConfidenceCalibrator::new();
        assert!(!calibrator.is_trained());

        let predictions = vec![0.9, 0.8, 0.7, 0.6, 0.5];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0, 0.0];

        let result = calibrator.train(&predictions, &ground_truth);
        assert!(result.is_ok());
        assert!(calibrator.is_trained());
    }

    #[test]
    fn test_calibrator_calibrate() {
        let mut calibrator = ConfidenceCalibrator::new();

        // Before training, should return raw score
        assert_eq!(calibrator.calibrate(0.8), 0.8);

        // Train with some data
        let predictions = vec![0.9, 0.9, 0.8, 0.8, 0.7, 0.7];
        let ground_truth = vec![1.0, 1.0, 1.0, 0.0, 0.0, 0.0];
        calibrator.train(&predictions, &ground_truth).unwrap();

        // After training, should return calibrated score
        let calibrated = calibrator.calibrate(0.85);
        assert!(calibrated >= 0.0 && calibrated <= 1.0);
    }

    #[test]
    fn test_calibrator_reset() {
        let mut calibrator = ConfidenceCalibrator::new();
        let predictions = vec![0.9, 0.8];
        let ground_truth = vec![1.0, 0.0];

        calibrator.train(&predictions, &ground_truth).unwrap();
        assert!(calibrator.is_trained());

        // Reset drops the learned mapping and the trained flag.
        calibrator.reset();
        assert!(!calibrator.is_trained());
    }

    #[test]
    fn test_calculate_ece() {
        let predictions = vec![0.9, 0.7, 0.6, 0.8];
        let ground_truth = vec![1.0, 1.0, 0.0, 1.0];
        let ece = calculate_ece(&predictions, &ground_truth, 3);
        assert!(ece >= 0.0 && ece <= 1.0);
    }

    #[test]
    fn test_calibrator_monotonicity() {
        let mut calibrator = ConfidenceCalibrator::new();

        // Create data that would violate monotonicity
        let predictions = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9];
        let ground_truth = vec![0.2, 0.3, 0.2, 0.5, 0.4, 0.7, 0.8, 0.9, 1.0];

        calibrator.train(&predictions, &ground_truth).unwrap();

        // Check monotonicity
        let score1 = calibrator.calibrate(0.3);
        let score2 = calibrator.calibrate(0.5);
        let score3 = calibrator.calibrate(0.7);

        assert!(score2 >= score1, "Calibrated scores should be monotonic");
        assert!(score3 >= score2, "Calibrated scores should be monotonic");
    }
}
|
||||
441
examples/scipix/src/ocr/decoder.rs
Normal file
441
examples/scipix/src/ocr/decoder.rs
Normal file
@@ -0,0 +1,441 @@
|
||||
//! Output Decoding Module
|
||||
//!
|
||||
//! This module provides various decoding strategies for converting
|
||||
//! model output logits into text strings.
|
||||
|
||||
use super::{OcrError, Result};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::debug;
|
||||
|
||||
/// Decoder trait for converting logits to text
pub trait Decoder: Send + Sync {
    /// Decode logits to text
    ///
    /// `logits` holds one inner Vec per time frame, each containing the raw
    /// per-class scores for that frame.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String>;

    /// Decode with confidence scores per character
    ///
    /// Default implementation decodes normally and reports a uniform 1.0
    /// confidence for every position.
    // NOTE(review): the default sizes the vector by `text.len()` (bytes, not
    // chars); fine for ASCII vocabularies, but multi-byte characters would
    // inflate the count — confirm.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        // Default implementation just returns uniform confidence
        let text = self.decode(logits)?;
        let confidences = vec![1.0; text.len()];
        Ok((text, confidences))
    }
}
|
||||
|
||||
/// Vocabulary mapping for character recognition
///
/// Bidirectional index↔character maps plus the CTC blank index. The blank
/// index need not appear in the maps (the default vocabulary places it one
/// past the last character).
#[derive(Debug, Clone)]
pub struct Vocabulary {
    /// Index to character mapping
    idx_to_char: HashMap<usize, char>,
    /// Character to index mapping
    char_to_idx: HashMap<char, usize>,
    /// Blank token index for CTC
    blank_idx: usize,
}
|
||||
|
||||
impl Vocabulary {
|
||||
/// Create a new vocabulary
|
||||
pub fn new(chars: Vec<char>, blank_idx: usize) -> Self {
|
||||
let idx_to_char: HashMap<usize, char> =
|
||||
chars.iter().enumerate().map(|(i, &c)| (i, c)).collect();
|
||||
let char_to_idx: HashMap<char, usize> =
|
||||
chars.iter().enumerate().map(|(i, &c)| (c, i)).collect();
|
||||
|
||||
Self {
|
||||
idx_to_char,
|
||||
char_to_idx,
|
||||
blank_idx,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get character by index
|
||||
pub fn get_char(&self, idx: usize) -> Option<char> {
|
||||
self.idx_to_char.get(&idx).copied()
|
||||
}
|
||||
|
||||
/// Get index by character
|
||||
pub fn get_idx(&self, ch: char) -> Option<usize> {
|
||||
self.char_to_idx.get(&ch).copied()
|
||||
}
|
||||
|
||||
/// Get blank token index
|
||||
pub fn blank_idx(&self) -> usize {
|
||||
self.blank_idx
|
||||
}
|
||||
|
||||
/// Get vocabulary size
|
||||
pub fn size(&self) -> usize {
|
||||
self.idx_to_char.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Vocabulary {
|
||||
fn default() -> Self {
|
||||
// Default vocabulary: lowercase letters + digits + space + blank
|
||||
let mut chars = Vec::new();
|
||||
|
||||
// Add lowercase letters
|
||||
for c in 'a'..='z' {
|
||||
chars.push(c);
|
||||
}
|
||||
|
||||
// Add digits
|
||||
for c in '0'..='9' {
|
||||
chars.push(c);
|
||||
}
|
||||
|
||||
// Add space
|
||||
chars.push(' ');
|
||||
|
||||
// Blank token is at the end
|
||||
let blank_idx = chars.len();
|
||||
|
||||
Self::new(chars, blank_idx)
|
||||
}
|
||||
}
|
||||
|
||||
/// Greedy decoder - selects the character with highest probability at each step
pub struct GreedyDecoder {
    // Shared vocabulary used to map class indices back to characters.
    vocabulary: Arc<Vocabulary>,
}
|
||||
|
||||
impl GreedyDecoder {
|
||||
/// Create a new greedy decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
|
||||
Self { vocabulary }
|
||||
}
|
||||
|
||||
/// Find the index with maximum value in a slice
|
||||
fn argmax(values: &[f32]) -> usize {
|
||||
values
|
||||
.iter()
|
||||
.enumerate()
|
||||
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
|
||||
.map(|(idx, _)| idx)
|
||||
.unwrap_or(0)
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for GreedyDecoder {
    /// Best-path decode: per-frame argmax, then CTC collapse (drop blanks,
    /// merge consecutive repeats).
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("Greedy decoding {} frames", logits.len());

        let mut result = String::new();
        let mut prev_idx = None;

        for frame_logits in logits {
            let idx = Self::argmax(frame_logits);

            // Skip blank tokens and repeated characters.
            // `prev_idx` is also updated on blanks, so a blank between two
            // equal symbols resets the repeat and a double letter is emitted.
            if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
                if let Some(ch) = self.vocabulary.get_char(idx) {
                    result.push(ch);
                }
            }

            prev_idx = Some(idx);
        }

        Ok(result)
    }

    /// Same best-path decode, additionally reporting the softmax
    /// max-probability of each emitting frame as the character's confidence.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let mut result = String::new();
        let mut confidences = Vec::new();
        let mut prev_idx = None;

        for frame_logits in logits {
            let idx = Self::argmax(frame_logits);
            let confidence = softmax_max(frame_logits);

            // Skip blank tokens and repeated characters
            if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
                if let Some(ch) = self.vocabulary.get_char(idx) {
                    result.push(ch);
                    confidences.push(confidence);
                }
            }

            prev_idx = Some(idx);
        }

        Ok((result, confidences))
    }
}
|
||||
|
||||
/// Beam search decoder - maintains top-k hypotheses for better accuracy
pub struct BeamSearchDecoder {
    // Shared vocabulary for index→character mapping.
    vocabulary: Arc<Vocabulary>,
    // Number of hypotheses kept per frame (always >= 1; see `new`).
    beam_width: usize,
}
|
||||
|
||||
impl BeamSearchDecoder {
|
||||
/// Create a new beam search decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>, beam_width: usize) -> Self {
|
||||
Self {
|
||||
vocabulary,
|
||||
beam_width: beam_width.max(1),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get beam width
|
||||
pub fn beam_width(&self) -> usize {
|
||||
self.beam_width
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for BeamSearchDecoder {
    /// Beam search over per-frame scores, keeping the `beam_width` best
    /// hypotheses per frame with CTC-style blank/repeat handling.
    // NOTE(review): scores are sums of RAW logits rather than
    // log-probabilities, and identical texts reached through different paths
    // are never merged — both deviate from standard CTC beam search. Confirm
    // whether that is intended.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!(
            "Beam search decoding {} frames (beam_width: {})",
            logits.len(),
            self.beam_width
        );

        if logits.is_empty() {
            return Ok(String::new());
        }

        // Initialize beams: (text, score, last_idx)
        let mut beams: Vec<(String, f32, Option<usize>)> = vec![(String::new(), 0.0, None)];

        for frame_logits in logits {
            let mut new_beams = Vec::new();

            for (text, score, last_idx) in &beams {
                // Get top-k predictions for this frame
                let mut indexed_logits: Vec<(usize, f32)> = frame_logits
                    .iter()
                    .enumerate()
                    .map(|(i, &v)| (i, v))
                    .collect();
                // NOTE(review): this unwrap() panics if any logit is NaN.
                indexed_logits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

                // Expand each beam with top-k predictions
                for (idx, logit) in indexed_logits.iter().take(self.beam_width) {
                    let new_score = score + logit;

                    // Skip blank tokens (beam survives with updated score).
                    if *idx == self.vocabulary.blank_idx() {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }

                    // Skip repeated characters (CTC collapse)
                    if Some(*idx) == *last_idx {
                        new_beams.push((text.clone(), new_score, Some(*idx)));
                        continue;
                    }

                    // Add character to beam (indices outside the vocabulary
                    // are silently dropped).
                    if let Some(ch) = self.vocabulary.get_char(*idx) {
                        let mut new_text = text.clone();
                        new_text.push(ch);
                        new_beams.push((new_text, new_score, Some(*idx)));
                    }
                }
            }

            // Keep top beam_width beams
            new_beams.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
            new_beams.truncate(self.beam_width);
            beams = new_beams;
        }

        // Return the best beam
        Ok(beams
            .first()
            .map(|(text, _, _)| text.clone())
            .unwrap_or_default())
    }
}
|
||||
|
||||
/// CTC (Connectionist Temporal Classification) decoder
pub struct CTCDecoder {
    // Shared vocabulary; its blank index drives the CTC collapse rule.
    vocabulary: Arc<Vocabulary>,
}
|
||||
|
||||
impl CTCDecoder {
|
||||
/// Create a new CTC decoder
|
||||
pub fn new(vocabulary: Arc<Vocabulary>) -> Self {
|
||||
Self { vocabulary }
|
||||
}
|
||||
|
||||
/// Collapse repeated characters and remove blanks
|
||||
fn collapse_repeats(&self, indices: &[usize]) -> Vec<usize> {
|
||||
let mut result = Vec::new();
|
||||
let mut prev_idx = None;
|
||||
|
||||
for &idx in indices {
|
||||
// Skip blanks
|
||||
if idx == self.vocabulary.blank_idx() {
|
||||
prev_idx = Some(idx);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip repeats
|
||||
if Some(idx) != prev_idx {
|
||||
result.push(idx);
|
||||
}
|
||||
|
||||
prev_idx = Some(idx);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl Decoder for CTCDecoder {
    /// Best-path CTC decode: per-frame argmax, then collapse repeats/blanks.
    fn decode(&self, logits: &[Vec<f32>]) -> Result<String> {
        debug!("CTC decoding {} frames", logits.len());

        // Get best path (greedy)
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();

        // Collapse repeats and remove blanks
        let collapsed = self.collapse_repeats(&indices);

        // Convert to text (indices missing from the vocabulary are dropped)
        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();

        Ok(text)
    }

    /// Same decode, additionally reporting per-character confidences: the
    /// softmax max-probability of the frame that emitted each character.
    fn decode_with_confidence(&self, logits: &[Vec<f32>]) -> Result<(String, Vec<f32>)> {
        let indices: Vec<usize> = logits
            .iter()
            .map(|frame| GreedyDecoder::argmax(frame))
            .collect();
        let confidences: Vec<f32> = logits.iter().map(|frame| softmax_max(frame)).collect();

        let collapsed = self.collapse_repeats(&indices);

        let text: String = collapsed
            .iter()
            .filter_map(|&idx| self.vocabulary.get_char(idx))
            .collect();

        // Map confidences to non-collapsed positions.
        // `confidence_idx` tracks the frame number, so the pushed value is
        // the confidence of the exact frame that emitted the character; the
        // bounds check is defensive (confidences has one entry per frame).
        let mut result_confidences = Vec::new();
        let mut prev_idx = None;
        let mut confidence_idx = 0;

        for &idx in &indices {
            if idx != self.vocabulary.blank_idx() && Some(idx) != prev_idx {
                if confidence_idx < confidences.len() {
                    result_confidences.push(confidences[confidence_idx]);
                }
            }
            confidence_idx += 1;
            prev_idx = Some(idx);
        }

        // NOTE(review): characters dropped by get_char (index not in vocab)
        // still push a confidence above, so text.len() and
        // result_confidences.len() can diverge — confirm acceptable.
        Ok((text, result_confidences))
    }
}
|
||||
|
||||
/// Calculate softmax and return max probability
|
||||
fn softmax_max(logits: &[f32]) -> f32 {
|
||||
if logits.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let max_logit = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
|
||||
let exp_sum: f32 = logits.iter().map(|&x| (x - max_logit).exp()).sum();
|
||||
|
||||
let max_exp = (logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b)) - max_logit).exp();
|
||||
max_exp / exp_sum
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Shared fixture: default vocabulary behind an Arc, as the decoders expect.
    fn create_test_vocabulary() -> Arc<Vocabulary> {
        Arc::new(Vocabulary::default())
    }

    // The default vocabulary is non-empty and maps 'a' <-> index 0.
    #[test]
    fn test_vocabulary_default() {
        let vocab = Vocabulary::default();
        assert!(vocab.size() > 0);
        assert_eq!(vocab.get_char(0), Some('a'));
        assert_eq!(vocab.get_idx('a'), Some(0));
    }

    // Greedy decoding picks the argmax per frame and drops blank frames.
    #[test]
    fn test_greedy_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab.clone());

        // Mock logits for "hi": frame 0 -> 'h', frame 1 -> blank, frame 2 -> 'i'
        let h_idx = vocab.get_idx('h').unwrap();
        let i_idx = vocab.get_idx('i').unwrap();
        let blank = vocab.blank_idx();

        // vocab.size() + 1 columns: all characters plus the blank symbol.
        let mut logits = vec![
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
            vec![0.0; vocab.size() + 1],
        ];

        logits[0][h_idx] = 10.0;
        logits[1][blank] = 10.0;
        logits[2][i_idx] = 10.0;

        let result = decoder.decode(&logits).unwrap();
        assert_eq!(result, "hi");
    }

    // Beam search with uniform (all-zero) logits should still succeed.
    #[test]
    fn test_beam_search_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = BeamSearchDecoder::new(vocab.clone(), 3);

        assert_eq!(decoder.beam_width(), 3);

        let logits = vec![vec![0.0; vocab.size() + 1]; 5];
        let result = decoder.decode(&logits);
        assert!(result.is_ok());
    }

    // CTC collapse: consecutive repeats merge, blanks are removed, and a
    // blank does not resurrect a repeated character here ('aa' -> 'a').
    #[test]
    fn test_ctc_decoder() {
        let vocab = create_test_vocabulary();
        let decoder = CTCDecoder::new(vocab.clone());

        // Test collapse repeats
        let a_idx = vocab.get_idx('a').unwrap();
        let b_idx = vocab.get_idx('b').unwrap();
        let blank = vocab.blank_idx();

        let indices = vec![a_idx, a_idx, blank, b_idx, b_idx, b_idx];
        let collapsed = decoder.collapse_repeats(&indices);

        assert_eq!(collapsed, vec![a_idx, b_idx]);
    }

    // softmax_max returns a valid probability that favors the dominant logit.
    #[test]
    fn test_softmax_max() {
        let logits = vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let max_prob = softmax_max(&logits);
        assert!(max_prob > 0.0 && max_prob <= 1.0);
        assert!(max_prob > 0.5); // The max should have high probability
    }

    // Decoding zero frames yields an empty string rather than an error.
    #[test]
    fn test_empty_logits() {
        let vocab = create_test_vocabulary();
        let decoder = GreedyDecoder::new(vocab);

        let empty_logits: Vec<Vec<f32>> = vec![];
        let result = decoder.decode(&empty_logits).unwrap();
        assert_eq!(result, "");
    }
}
|
||||
363
examples/scipix/src/ocr/engine.rs
Normal file
363
examples/scipix/src/ocr/engine.rs
Normal file
@@ -0,0 +1,363 @@
|
||||
//! OCR Engine Implementation
|
||||
//!
|
||||
//! This module provides the main OcrEngine for orchestrating OCR operations.
|
||||
//! It handles model loading, inference coordination, and result assembly.
|
||||
|
||||
use super::{
|
||||
confidence::aggregate_confidence,
|
||||
decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary},
|
||||
inference::{DetectionResult, InferenceEngine, RecognitionResult},
|
||||
models::{ModelHandle, ModelRegistry},
|
||||
Character, DecoderType, OcrError, OcrOptions, OcrResult, RegionType, Result, TextRegion,
|
||||
};
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// OCR processor trait for custom implementations
///
/// A synchronous, object-safe entry point for OCR backends. Implementors
/// must be `Send + Sync` so a single processor can be shared across threads.
pub trait OcrProcessor: Send + Sync {
    /// Process an image and return OCR results
    ///
    /// `image_data` is the raw encoded image bytes (e.g. PNG/JPEG).
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult>;

    /// Batch process multiple images
    ///
    /// Returns one result per input image, in input order; fails on the
    /// first image that cannot be processed.
    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>>;
}
|
||||
|
||||
/// Main OCR Engine with thread-safe model management
///
/// Shared state lives behind `Arc`/`RwLock`, so the engine can be used
/// from multiple threads once constructed.
pub struct OcrEngine {
    /// Model registry for loading and caching models
    registry: Arc<RwLock<ModelRegistry>>,

    /// Inference engine for running ONNX models
    inference: Arc<InferenceEngine>,

    /// Default OCR options
    ///
    /// Used by [`OcrEngine::recognize`]; callers can override per-call via
    /// `recognize_with_options`.
    default_options: OcrOptions,

    /// Vocabulary for decoding
    vocabulary: Arc<Vocabulary>,

    /// Whether the engine is warmed up
    ///
    /// Set once by `warmup`; guards against repeated dummy inferences.
    warmed_up: Arc<RwLock<bool>>,
}
|
||||
|
||||
impl OcrEngine {
    /// Create a new OCR engine with default models
    ///
    /// # Example
    ///
    /// ```no_run
    /// # use ruvector_scipix::ocr::OcrEngine;
    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
    /// let engine = OcrEngine::new().await?;
    /// # Ok(())
    /// # }
    /// ```
    pub async fn new() -> Result<Self> {
        Self::with_options(OcrOptions::default()).await
    }

    /// Create a new OCR engine with custom options
    ///
    /// Loads the detection and recognition models, and — when
    /// `options.enable_math` is set — the math model as well.
    ///
    /// # Errors
    ///
    /// Returns `OcrError::ModelLoading` if any required model fails to load.
    pub async fn with_options(options: OcrOptions) -> Result<Self> {
        info!("Initializing OCR engine with options: {:?}", options);

        // Initialize model registry
        let registry = Arc::new(RwLock::new(ModelRegistry::new()));

        // Load default models (in production, these would be downloaded/cached)
        //
        // NOTE(review): each `registry.write()` below produces a guard that is
        // held across the subsequent `.await` — confirm the load futures
        // cannot suspend for long (or re-enter the registry) while the
        // write lock is held.
        debug!("Loading detection model...");
        let detection_model = registry.write().load_detection_model().await.map_err(|e| {
            OcrError::ModelLoading(format!("Failed to load detection model: {}", e))
        })?;

        debug!("Loading recognition model...");
        let recognition_model = registry
            .write()
            .load_recognition_model()
            .await
            .map_err(|e| {
                OcrError::ModelLoading(format!("Failed to load recognition model: {}", e))
            })?;

        // Math model is optional and only loaded when math OCR is enabled.
        let math_model =
            if options.enable_math {
                debug!("Loading math recognition model...");
                Some(registry.write().load_math_model().await.map_err(|e| {
                    OcrError::ModelLoading(format!("Failed to load math model: {}", e))
                })?)
            } else {
                None
            };

        // Create inference engine
        let inference = Arc::new(InferenceEngine::new(
            detection_model,
            recognition_model,
            math_model,
            options.use_gpu,
        )?);

        // Load vocabulary
        let vocabulary = Arc::new(Vocabulary::default());

        let engine = Self {
            registry,
            inference,
            default_options: options,
            vocabulary,
            warmed_up: Arc::new(RwLock::new(false)),
        };

        info!("OCR engine initialized successfully");
        Ok(engine)
    }

    /// Warm up the engine by running a dummy inference
    ///
    /// This helps reduce latency for the first real inference by initializing
    /// all ONNX runtime resources.
    pub async fn warmup(&self) -> Result<()> {
        if *self.warmed_up.read() {
            debug!("Engine already warmed up, skipping");
            return Ok(());
        }

        info!("Warming up OCR engine...");
        let start = Instant::now();

        // Create a small dummy image (100x100 black image)
        let dummy_image = vec![0u8; 100 * 100 * 3];

        // Run a dummy inference; the result (including any error) is
        // deliberately discarded — warmup only exists for its side effects.
        let _ = self.recognize(&dummy_image).await;

        *self.warmed_up.write() = true;
        info!("Engine warmup completed in {:?}", start.elapsed());
        Ok(())
    }

    /// Recognize text in an image using default options
    pub async fn recognize(&self, image_data: &[u8]) -> Result<OcrResult> {
        self.recognize_with_options(image_data, &self.default_options)
            .await
    }

    /// Recognize text in an image with custom options
    ///
    /// Pipeline: detect text regions, recognize each region (math or text),
    /// decode the logits, filter by confidence, then assemble an
    /// [`OcrResult`] whose `text` joins all surviving regions with spaces.
    pub async fn recognize_with_options(
        &self,
        image_data: &[u8],
        options: &OcrOptions,
    ) -> Result<OcrResult> {
        let start = Instant::now();
        debug!("Starting OCR recognition");

        // Step 1: Run text detection
        debug!("Running text detection...");
        let detection_results = self
            .inference
            .run_detection(image_data, options.detection_threshold)
            .await?;

        debug!("Detected {} regions", detection_results.len());

        // No regions: return an empty result rather than an error.
        if detection_results.is_empty() {
            warn!("No text regions detected");
            return Ok(OcrResult {
                text: String::new(),
                confidence: 0.0,
                regions: vec![],
                has_math: false,
                processing_time_ms: start.elapsed().as_millis() as u64,
            });
        }

        // Step 2: Run recognition on each detected region
        debug!("Running text recognition...");
        let mut text_regions = Vec::new();
        let mut has_math = false;

        for detection in detection_results {
            // Determine region type (math only when enabled AND flagged likely)
            let region_type = if options.enable_math && detection.is_math_likely {
                has_math = true;
                RegionType::Math
            } else {
                RegionType::Text
            };

            // Run appropriate recognition
            let recognition = if region_type == RegionType::Math {
                self.inference
                    .run_math_recognition(&detection.region_image, options)
                    .await?
            } else {
                self.inference
                    .run_recognition(&detection.region_image, options)
                    .await?
            };

            // Decode the recognition output
            let decoded_text = self.decode_output(&recognition, options)?;

            // Calculate confidence
            let confidence = aggregate_confidence(&recognition.character_confidences);

            // Filter by confidence threshold
            if confidence < options.recognition_threshold {
                debug!(
                    "Skipping region with low confidence: {:.2} < {:.2}",
                    confidence, options.recognition_threshold
                );
                continue;
            }

            // Build character list; `zip` truncates to the shorter of the
            // decoded text and the per-character confidences.
            let characters = decoded_text
                .chars()
                .zip(recognition.character_confidences.iter())
                .map(|(ch, &conf)| Character {
                    char: ch,
                    confidence: conf,
                    bbox: None, // Could be populated if available from model
                })
                .collect();

            text_regions.push(TextRegion {
                bbox: detection.bbox,
                text: decoded_text,
                confidence,
                region_type,
                characters,
            });
        }

        // Step 3: Combine results
        let combined_text = text_regions
            .iter()
            .map(|r| r.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");

        // Unweighted mean of the surviving regions' confidences.
        let overall_confidence = if text_regions.is_empty() {
            0.0
        } else {
            text_regions.iter().map(|r| r.confidence).sum::<f32>() / text_regions.len() as f32
        };

        let processing_time_ms = start.elapsed().as_millis() as u64;

        debug!(
            "OCR completed in {}ms, recognized {} regions",
            processing_time_ms,
            text_regions.len()
        );

        Ok(OcrResult {
            text: combined_text,
            confidence: overall_confidence,
            regions: text_regions,
            has_math,
            processing_time_ms,
        })
    }

    /// Batch process multiple images
    ///
    /// Results are returned in input order; processing stops at the first
    /// failing image.
    pub async fn recognize_batch(
        &self,
        images: &[&[u8]],
        options: &OcrOptions,
    ) -> Result<Vec<OcrResult>> {
        info!("Processing batch of {} images", images.len());
        let start = Instant::now();

        // Images are processed sequentially: each `block_on` drives one
        // recognition to completion before the next starts. (The original
        // comment claimed rayon parallelism — there is none here.)
        //
        // NOTE(review): `block_on` inside an async fn blocks the executor
        // thread; confirm this is never called from a multi-task runtime.
        let results: Result<Vec<OcrResult>> = images
            .iter()
            .map(|image_data| {
                // Note: In a real async implementation, we'd use tokio::spawn
                // For now, we'll use blocking since we're in a sync context
                futures::executor::block_on(self.recognize_with_options(image_data, options))
            })
            .collect();

        info!("Batch processing completed in {:?}", start.elapsed());

        results
    }

    /// Decode recognition output using the selected decoder
    ///
    /// Dispatches on `options.decoder_type`; all three decoders share the
    /// engine's vocabulary.
    fn decode_output(
        &self,
        recognition: &RecognitionResult,
        options: &OcrOptions,
    ) -> Result<String> {
        debug!("Decoding output with {:?} decoder", options.decoder_type);

        let decoded = match options.decoder_type {
            DecoderType::BeamSearch => {
                let decoder = BeamSearchDecoder::new(self.vocabulary.clone(), options.beam_width);
                decoder.decode(&recognition.logits)?
            }
            DecoderType::Greedy => {
                let decoder = GreedyDecoder::new(self.vocabulary.clone());
                decoder.decode(&recognition.logits)?
            }
            DecoderType::CTC => {
                let decoder = CTCDecoder::new(self.vocabulary.clone());
                decoder.decode(&recognition.logits)?
            }
        };

        Ok(decoded)
    }

    /// Get the current model registry
    pub fn registry(&self) -> Arc<RwLock<ModelRegistry>> {
        Arc::clone(&self.registry)
    }

    /// Get the default options
    pub fn default_options(&self) -> &OcrOptions {
        &self.default_options
    }

    /// Check if engine is warmed up
    pub fn is_warmed_up(&self) -> bool {
        *self.warmed_up.read()
    }
}
|
||||
|
||||
impl OcrProcessor for OcrEngine {
    fn process(&self, image_data: &[u8], options: &OcrOptions) -> Result<OcrResult> {
        // Blocking wrapper for async method
        // NOTE(review): `block_on` may deadlock if invoked from inside an
        // async runtime thread — confirm callers are synchronous contexts.
        futures::executor::block_on(self.recognize_with_options(image_data, options))
    }

    fn process_batch(&self, images: &[&[u8]], options: &OcrOptions) -> Result<Vec<OcrResult>> {
        // Blocking wrapper for async method (same caveat as `process`).
        futures::executor::block_on(self.recognize_batch(images, options))
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Struct-update syntax preserves the chosen decoder type.
    #[test]
    fn test_decoder_selection() {
        let options = OcrOptions {
            decoder_type: DecoderType::BeamSearch,
            ..Default::default()
        };
        assert_eq!(options.decoder_type, DecoderType::BeamSearch);
    }

    // Sanity check of the Arc<RwLock<bool>> pattern used for `warmed_up`.
    #[test]
    fn test_warmup_flag() {
        let flag = Arc::new(RwLock::new(false));
        assert!(!*flag.read());
        *flag.write() = true;
        assert!(*flag.read());
    }
}
|
||||
790
examples/scipix/src/ocr/inference.rs
Normal file
790
examples/scipix/src/ocr/inference.rs
Normal file
@@ -0,0 +1,790 @@
|
||||
//! ONNX Inference Module
|
||||
//!
|
||||
//! This module handles ONNX inference operations for text detection,
|
||||
//! character recognition, and mathematical expression recognition.
|
||||
//!
|
||||
//! # Model Requirements
|
||||
//!
|
||||
//! This module requires ONNX models to be available in the configured model directory.
|
||||
//! Without models, all inference operations will return errors.
|
||||
//!
|
||||
//! To use this module:
|
||||
//! 1. Download compatible ONNX models (PaddleOCR, TrOCR, or similar)
|
||||
//! 2. Place them in the models directory
|
||||
//! 3. Enable the `ocr` feature flag
|
||||
|
||||
use super::{models::ModelHandle, OcrError, OcrOptions, Result};
|
||||
use image::{DynamicImage, GenericImageView};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
use ndarray::Array4;
|
||||
#[cfg(feature = "ocr")]
|
||||
use ort::value::Tensor;
|
||||
|
||||
/// Result from text detection
#[derive(Debug, Clone)]
pub struct DetectionResult {
    /// Bounding box [x, y, width, height]
    ///
    /// Expressed in pixels of the original input image.
    pub bbox: [f32; 4],
    /// Detection confidence
    pub confidence: f32,
    /// Cropped image region
    ///
    /// PNG-encoded bytes of the cropped bounding box.
    pub region_image: Vec<u8>,
    /// Whether this region likely contains math
    ///
    /// Heuristic based on the region's aspect ratio (very wide or very tall).
    pub is_math_likely: bool,
}
|
||||
|
||||
/// Result from text/math recognition
#[derive(Debug, Clone)]
pub struct RecognitionResult {
    /// Logits output from the model [sequence_length, vocab_size]
    pub logits: Vec<Vec<f32>>,
    /// Character-level confidence scores
    ///
    /// One entry per sequence step: the maximum softmax probability of
    /// that step's distribution.
    pub character_confidences: Vec<f32>,
    /// Raw output tensor (for debugging)
    pub raw_output: Option<Vec<f32>>,
}
|
||||
|
||||
/// Inference engine for running ONNX models
///
/// IMPORTANT: This engine requires ONNX models to be loaded.
/// All methods will return errors if models are not properly initialized.
pub struct InferenceEngine {
    /// Detection model
    detection_model: Arc<ModelHandle>,
    /// Recognition model
    recognition_model: Arc<ModelHandle>,
    /// Math recognition model (optional)
    ///
    /// When absent or unloaded, math recognition falls back to text
    /// recognition.
    math_model: Option<Arc<ModelHandle>>,
    /// Whether to use GPU acceleration
    use_gpu: bool,
    /// Whether models are actually loaded (vs placeholder handles)
    ///
    /// Computed once at construction from the detection and recognition
    /// handles; the math model does not affect this flag.
    models_loaded: bool,
}
|
||||
|
||||
impl InferenceEngine {
|
||||
/// Create a new inference engine
|
||||
pub fn new(
|
||||
detection_model: Arc<ModelHandle>,
|
||||
recognition_model: Arc<ModelHandle>,
|
||||
math_model: Option<Arc<ModelHandle>>,
|
||||
use_gpu: bool,
|
||||
) -> Result<Self> {
|
||||
// Check if models are actually loaded with ONNX sessions
|
||||
let detection_loaded = detection_model.is_loaded();
|
||||
let recognition_loaded = recognition_model.is_loaded();
|
||||
let models_loaded = detection_loaded && recognition_loaded;
|
||||
|
||||
if !models_loaded {
|
||||
warn!(
|
||||
"ONNX models not fully loaded. Detection: {}, Recognition: {}",
|
||||
detection_loaded, recognition_loaded
|
||||
);
|
||||
warn!("OCR inference will fail until models are properly configured.");
|
||||
} else {
|
||||
info!(
|
||||
"Inference engine initialized with loaded models (GPU: {})",
|
||||
if use_gpu { "enabled" } else { "disabled" }
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
detection_model,
|
||||
recognition_model,
|
||||
math_model,
|
||||
use_gpu,
|
||||
models_loaded,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if the inference engine is ready for use
///
/// True only when both the detection and recognition models had ONNX
/// sessions at construction time.
pub fn is_ready(&self) -> bool {
    self.models_loaded
}
|
||||
|
||||
/// Run text detection on an image
///
/// `image_data` holds the raw encoded image; `threshold` is the minimum
/// detection confidence for a region to be kept.
///
/// # Errors
///
/// Fails with a model-loading error when models are not configured, or an
/// inference error when the `ocr` feature is disabled.
pub async fn run_detection(
    &self,
    image_data: &[u8],
    threshold: f32,
) -> Result<Vec<DetectionResult>> {
    if !self.models_loaded {
        return Err(OcrError::ModelLoading(
            "ONNX models not loaded. Please download and configure OCR models before use. \
             See examples/scipix/docs/MODEL_SETUP.md for instructions."
                .to_string(),
        ));
    }

    debug!("Running text detection (threshold: {})", threshold);
    let input_tensor = self.preprocess_image_for_detection(image_data)?;

    // Exactly one of the two cfg blocks below is compiled in; the early
    // `return` keeps the function well-typed in both configurations.
    #[cfg(feature = "ocr")]
    {
        let detections = self
            .run_onnx_detection(&input_tensor, threshold, image_data)
            .await?;
        debug!("Detected {} regions", detections.len());
        return Ok(detections);
    }

    #[cfg(not(feature = "ocr"))]
    {
        Err(OcrError::Inference(
            "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                .to_string(),
        ))
    }
}
|
||||
|
||||
/// Run text recognition on a region image
///
/// `region_image` is an encoded crop produced by detection. Returns the
/// model's per-step logits and confidences.
///
/// # Errors
///
/// Fails with a model-loading error when models are not configured, or an
/// inference error when the `ocr` feature is disabled.
pub async fn run_recognition(
    &self,
    region_image: &[u8],
    options: &OcrOptions,
) -> Result<RecognitionResult> {
    if !self.models_loaded {
        return Err(OcrError::ModelLoading(
            "ONNX models not loaded. Please download and configure OCR models before use."
                .to_string(),
        ));
    }

    debug!("Running text recognition");
    let input_tensor = self.preprocess_image_for_recognition(region_image)?;

    // Only one cfg branch is compiled; see run_detection for the pattern.
    #[cfg(feature = "ocr")]
    {
        let result = self.run_onnx_recognition(&input_tensor, options).await?;
        return Ok(result);
    }

    #[cfg(not(feature = "ocr"))]
    {
        Err(OcrError::Inference(
            "OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
                .to_string(),
        ))
    }
}
|
||||
|
||||
/// Run math recognition on a region image
|
||||
pub async fn run_math_recognition(
|
||||
&self,
|
||||
region_image: &[u8],
|
||||
options: &OcrOptions,
|
||||
) -> Result<RecognitionResult> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Please download and configure OCR models before use."
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running math recognition");
|
||||
|
||||
if self.math_model.is_none() || !self.math_model.as_ref().unwrap().is_loaded() {
|
||||
warn!("Math model not loaded, falling back to text recognition");
|
||||
return self.run_recognition(region_image, options).await;
|
||||
}
|
||||
|
||||
let input_tensor = self.preprocess_image_for_math(region_image)?;
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
{
|
||||
let result = self
|
||||
.run_onnx_math_recognition(&input_tensor, options)
|
||||
.await?;
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "ocr"))]
|
||||
{
|
||||
Err(OcrError::Inference(
|
||||
"OCR feature not enabled. Rebuild with `--features ocr` to enable ONNX inference."
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Preprocess image for detection model
|
||||
fn preprocess_image_for_detection(&self, image_data: &[u8]) -> Result<Vec<f32>> {
|
||||
let img = image::load_from_memory(image_data)
|
||||
.map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;
|
||||
|
||||
let input_shape = self.detection_model.input_shape();
|
||||
let (_, _, height, width) = (
|
||||
input_shape[0],
|
||||
input_shape[1],
|
||||
input_shape[2],
|
||||
input_shape[3],
|
||||
);
|
||||
|
||||
let resized = img.resize_exact(
|
||||
width as u32,
|
||||
height as u32,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
|
||||
let rgb = resized.to_rgb8();
|
||||
let mut tensor = Vec::with_capacity(3 * height * width);
|
||||
|
||||
// Convert to NCHW format with normalization
|
||||
for c in 0..3 {
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = rgb.get_pixel(x as u32, y as u32);
|
||||
tensor.push(pixel[c] as f32 / 255.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tensor)
|
||||
}
|
||||
|
||||
/// Preprocess image for recognition model
///
/// Resizes to the model's input size and normalizes pixels to [-1, 1]
/// via (v / 127.5) - 1. Emits a single grayscale plane when the model
/// expects one channel, otherwise planar RGB in NCHW order.
fn preprocess_image_for_recognition(&self, image_data: &[u8]) -> Result<Vec<f32>> {
    let img = image::load_from_memory(image_data)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;

    // NCHW input shape: [batch, channels, height, width].
    let input_shape = self.recognition_model.input_shape();
    let (_, channels, height, width) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
    );

    let resized = img.resize_exact(
        width as u32,
        height as u32,
        image::imageops::FilterType::Lanczos3,
    );

    let mut tensor = Vec::with_capacity(channels * height * width);

    if channels == 1 {
        let gray = resized.to_luma8();
        for y in 0..height {
            for x in 0..width {
                let pixel = gray.get_pixel(x as u32, y as u32);
                tensor.push((pixel[0] as f32 / 127.5) - 1.0);
            }
        }
    } else {
        let rgb = resized.to_rgb8();
        // NOTE(review): this loop covers exactly 3 channels, while the
        // otherwise-identical preprocess_image_for_math iterates
        // `0..channels`; if `channels` can ever differ from 3 here, the
        // produced tensor length will not match the reserved capacity —
        // confirm which behavior is intended and align the two.
        for c in 0..3 {
            for y in 0..height {
                for x in 0..width {
                    let pixel = rgb.get_pixel(x as u32, y as u32);
                    tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                }
            }
        }
    }

    Ok(tensor)
}
|
||||
|
||||
/// Preprocess image for math recognition model
///
/// Same pipeline as `preprocess_image_for_recognition` (resize, normalize
/// to [-1, 1], NCHW layout) but driven by the math model's input shape.
///
/// # Errors
///
/// Fails when the math model is absent or the image cannot be decoded.
fn preprocess_image_for_math(&self, image_data: &[u8]) -> Result<Vec<f32>> {
    let math_model = self
        .math_model
        .as_ref()
        .ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;

    let img = image::load_from_memory(image_data)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;

    // NCHW input shape: [batch, channels, height, width].
    let input_shape = math_model.input_shape();
    let (_, channels, height, width) = (
        input_shape[0],
        input_shape[1],
        input_shape[2],
        input_shape[3],
    );

    let resized = img.resize_exact(
        width as u32,
        height as u32,
        image::imageops::FilterType::Lanczos3,
    );

    let mut tensor = Vec::with_capacity(channels * height * width);

    if channels == 1 {
        let gray = resized.to_luma8();
        for y in 0..height {
            for x in 0..width {
                let pixel = gray.get_pixel(x as u32, y as u32);
                tensor.push((pixel[0] as f32 / 127.5) - 1.0);
            }
        }
    } else {
        let rgb = resized.to_rgb8();
        // Unlike the recognition variant, this iterates all `channels`;
        // note that `rgb` only has 3 sub-pixels, so channels > 3 would
        // panic on `pixel[c]` — presumably channels is always 1 or 3.
        for c in 0..channels {
            for y in 0..height {
                for x in 0..width {
                    let pixel = rgb.get_pixel(x as u32, y as u32);
                    tensor.push((pixel[c] as f32 / 127.5) - 1.0);
                }
            }
        }
    }

    Ok(tensor)
}
|
||||
|
||||
/// ONNX detection inference (requires `ocr` feature)
///
/// Builds an NCHW tensor from the preprocessed data, runs the detection
/// session, extracts the first output tensor as f32, and converts it into
/// `DetectionResult`s against the original image.
#[cfg(feature = "ocr")]
async fn run_onnx_detection(
    &self,
    input_tensor: &[f32],
    threshold: f32,
    original_image: &[u8],
) -> Result<Vec<DetectionResult>> {
    let session_arc = self.detection_model.session().ok_or_else(|| {
        OcrError::OnnxRuntime("Detection model session not loaded".to_string())
    })?;
    let mut session = session_arc.lock();

    let input_shape = self.detection_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();

    // Create tensor from input data
    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;

    // Convert to dynamic-dimension view and create ORT tensor
    // (this shadows the `input_tensor: &[f32]` parameter intentionally).
    let input_dyn = input_array.into_dyn();
    let input_tensor = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;

    // Run inference
    let outputs = session
        .run(ort::inputs![input_tensor])
        .map_err(|e| OcrError::OnnxRuntime(format!("Inference failed: {}", e)))?;

    // Only the first model output is consumed.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;

    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();

    // The original image is re-decoded here so regions can be cropped.
    let original_img = image::load_from_memory(original_image)
        .map_err(|e| OcrError::ImageProcessing(format!("Failed to decode image: {}", e)))?;

    let detections = self.parse_detection_output(&output_data, threshold, &original_img)?;
    Ok(detections)
}
|
||||
|
||||
/// Parse detection model output
///
/// Interprets the flat output as rows of `detection_size` floats where the
/// first five entries are (cx, cy, w, h, confidence) in normalized image
/// coordinates — presumably a YOLO-style head; confirm against the actual
/// model. Each accepted detection is cropped from the original image and
/// PNG-encoded.
#[cfg(feature = "ocr")]
fn parse_detection_output(
    &self,
    output: &[f32],
    threshold: f32,
    original_img: &DynamicImage,
) -> Result<Vec<DetectionResult>> {
    let mut results = Vec::new();
    let output_shape = self.detection_model.output_shape();

    if output_shape.len() >= 2 {
        let num_detections = output_shape[1];
        // Fall back to 85 floats per detection (80-class YOLO layout) when
        // the output shape does not specify a row size.
        let detection_size = if output_shape.len() >= 3 {
            output_shape[2]
        } else {
            85
        };

        for i in 0..num_detections {
            let base_idx = i * detection_size;
            // Stop rather than read past the actual output buffer.
            if base_idx + 5 > output.len() {
                break;
            }

            let confidence = output[base_idx + 4];
            if confidence < threshold {
                continue;
            }

            // Center-format box in normalized [0, 1] coordinates.
            let cx = output[base_idx];
            let cy = output[base_idx + 1];
            let w = output[base_idx + 2];
            let h = output[base_idx + 3];

            let img_width = original_img.width() as f32;
            let img_height = original_img.height() as f32;

            // Convert to a top-left pixel box, clamped inside the image.
            let x = ((cx - w / 2.0) * img_width).max(0.0);
            let y = ((cy - h / 2.0) * img_height).max(0.0);
            let width = (w * img_width).min(img_width - x);
            let height = (h * img_height).min(img_height - y);

            if width <= 0.0 || height <= 0.0 {
                continue;
            }

            let cropped =
                original_img.crop_imm(x as u32, y as u32, width as u32, height as u32);

            // PNG-encode the crop for downstream recognition.
            let mut region_bytes = Vec::new();
            cropped
                .write_to(
                    &mut std::io::Cursor::new(&mut region_bytes),
                    image::ImageFormat::Png,
                )
                .map_err(|e| {
                    OcrError::ImageProcessing(format!("Failed to encode region: {}", e))
                })?;

            // Heuristic: extreme aspect ratios are flagged as likely math.
            let aspect_ratio = width / height;
            let is_math_likely = aspect_ratio > 2.0 || aspect_ratio < 0.5;

            results.push(DetectionResult {
                bbox: [x, y, width, height],
                confidence,
                region_image: region_bytes,
                is_math_likely,
            });
        }
    }

    Ok(results)
}
|
||||
|
||||
/// ONNX recognition inference (requires `ocr` feature)
///
/// Runs the recognition session and reshapes the flat output into
/// per-step logit rows, computing each step's max softmax probability as
/// its confidence.
#[cfg(feature = "ocr")]
async fn run_onnx_recognition(
    &self,
    input_tensor: &[f32],
    _options: &OcrOptions,
) -> Result<RecognitionResult> {
    let session_arc = self.recognition_model.session().ok_or_else(|| {
        OcrError::OnnxRuntime("Recognition model session not loaded".to_string())
    })?;
    let mut session = session_arc.lock();

    let input_shape = self.recognition_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();

    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;

    let input_dyn = input_array.into_dyn();
    let input_ort = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;

    let outputs = session
        .run(ort::inputs![input_ort])
        .map_err(|e| OcrError::OnnxRuntime(format!("Recognition inference failed: {}", e)))?;

    // Only the first model output is consumed.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;

    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();

    // Fallback dimensions (26 steps, 37-symbol vocab) when the model's
    // declared output shape is missing entries.
    let output_shape = self.recognition_model.output_shape();
    let seq_len = output_shape.get(1).copied().unwrap_or(26);
    let vocab_size = output_shape.get(2).copied().unwrap_or(37);

    let mut logits = Vec::new();
    let mut character_confidences = Vec::new();

    for i in 0..seq_len {
        let start_idx = i * vocab_size;
        let end_idx = start_idx + vocab_size;

        // Steps beyond the actual buffer are silently dropped, so the
        // result may contain fewer than `seq_len` rows.
        if end_idx <= output_data.len() {
            let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();

            // Numerically stable softmax: shift by the row max first.
            let max_logit = step_logits
                .iter()
                .cloned()
                .fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();

            let softmax: Vec<f32> = step_logits
                .iter()
                .map(|&x| (x - max_logit).exp() / exp_sum)
                .collect();

            // Confidence for this step = probability of its best class.
            let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
            character_confidences.push(max_confidence);
            logits.push(step_logits);
        }
    }

    Ok(RecognitionResult {
        logits,
        character_confidences,
        raw_output: Some(output_data),
    })
}
|
||||
|
||||
/// ONNX math recognition inference (requires `ocr` feature)
///
/// Same pipeline as `run_onnx_recognition` but against the optional math
/// model, with a larger default output shape (50 timesteps, 512-symbol
/// vocabulary).
///
/// # Errors
/// Returns `OcrError::Inference` if no math model is loaded or the input
/// shape mismatches; `OcrError::OnnxRuntime` for session/tensor/inference
/// failures.
#[cfg(feature = "ocr")]
async fn run_onnx_math_recognition(
    &self,
    input_tensor: &[f32],
    _options: &OcrOptions,
) -> Result<RecognitionResult> {
    // The math model is optional — absence is a recoverable error.
    let math_model = self
        .math_model
        .as_ref()
        .ok_or_else(|| OcrError::Inference("Math model not loaded".to_string()))?;

    let session_arc = math_model
        .session()
        .ok_or_else(|| OcrError::OnnxRuntime("Math model session not loaded".to_string()))?;
    let mut session = session_arc.lock();

    // NOTE(review): shape[0..=3] assumes 4-D input metadata — see
    // `run_onnx_recognition`.
    let input_shape = math_model.input_shape();
    let shape: Vec<usize> = input_shape.to_vec();

    let input_array = Array4::from_shape_vec(
        (shape[0], shape[1], shape[2], shape[3]),
        input_tensor.to_vec(),
    )
    .map_err(|e| OcrError::Inference(format!("Failed to create input tensor: {}", e)))?;

    let input_dyn = input_array.into_dyn();
    let input_ort = Tensor::from_array(input_dyn)
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to create ORT tensor: {}", e)))?;

    let outputs = session.run(ort::inputs![input_ort]).map_err(|e| {
        OcrError::OnnxRuntime(format!("Math recognition inference failed: {}", e))
    })?;

    // First named output tensor is taken as the logits tensor.
    let output_tensor = outputs
        .iter()
        .next()
        .map(|(_, v)| v)
        .ok_or_else(|| OcrError::OnnxRuntime("No output tensor found".to_string()))?;

    let (_, raw_data) = output_tensor
        .try_extract_tensor::<f32>()
        .map_err(|e| OcrError::OnnxRuntime(format!("Failed to extract output: {}", e)))?;
    let output_data: Vec<f32> = raw_data.to_vec();

    // Defaults match the math model metadata in `models.rs`.
    let output_shape = math_model.output_shape();
    let seq_len = output_shape.get(1).copied().unwrap_or(50);
    let vocab_size = output_shape.get(2).copied().unwrap_or(512);

    let mut logits = Vec::new();
    let mut character_confidences = Vec::new();

    // Stable softmax per timestep; top probability becomes the confidence.
    for i in 0..seq_len {
        let start_idx = i * vocab_size;
        let end_idx = start_idx + vocab_size;

        if end_idx <= output_data.len() {
            let step_logits: Vec<f32> = output_data[start_idx..end_idx].to_vec();

            let max_logit = step_logits
                .iter()
                .cloned()
                .fold(f32::NEG_INFINITY, f32::max);
            let exp_sum: f32 = step_logits.iter().map(|&x| (x - max_logit).exp()).sum();

            let softmax: Vec<f32> = step_logits
                .iter()
                .map(|&x| (x - max_logit).exp() / exp_sum)
                .collect();

            let max_confidence = softmax.iter().cloned().fold(0.0f32, f32::max);
            character_confidences.push(max_confidence);
            logits.push(step_logits);
        }
    }

    Ok(RecognitionResult {
        logits,
        character_confidences,
        raw_output: Some(output_data),
    })
}
|
||||
|
||||
/// Get a handle to the text detection model.
pub fn detection_model(&self) -> &ModelHandle {
    &self.detection_model
}
|
||||
|
||||
/// Get a handle to the text recognition model.
pub fn recognition_model(&self) -> &ModelHandle {
    &self.recognition_model
}
|
||||
|
||||
/// Get math model if available
|
||||
pub fn math_model(&self) -> Option<&ModelHandle> {
|
||||
self.math_model.as_ref().map(|m| m.as_ref())
|
||||
}
|
||||
|
||||
/// Check if GPU acceleration was requested when this engine was created.
pub fn is_gpu_enabled(&self) -> bool {
    self.use_gpu
}
|
||||
}
|
||||
|
||||
/// Batch inference optimization
|
||||
impl InferenceEngine {
|
||||
/// Run batch detection on multiple images
|
||||
pub async fn run_batch_detection(
|
||||
&self,
|
||||
images: &[&[u8]],
|
||||
threshold: f32,
|
||||
) -> Result<Vec<Vec<DetectionResult>>> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Cannot run batch detection.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running batch detection on {} images", images.len());
|
||||
|
||||
let mut results = Vec::new();
|
||||
for image in images {
|
||||
let detections = self.run_detection(image, threshold).await?;
|
||||
results.push(detections);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
|
||||
/// Run batch recognition on multiple regions
|
||||
pub async fn run_batch_recognition(
|
||||
&self,
|
||||
regions: &[&[u8]],
|
||||
options: &OcrOptions,
|
||||
) -> Result<Vec<RecognitionResult>> {
|
||||
if !self.models_loaded {
|
||||
return Err(OcrError::ModelLoading(
|
||||
"ONNX models not loaded. Cannot run batch recognition.".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
debug!("Running batch recognition on {} regions", regions.len());
|
||||
|
||||
let mut results = Vec::new();
|
||||
for region in regions {
|
||||
let result = self.run_recognition(region, options).await?;
|
||||
results.push(result);
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ocr::models::{ModelMetadata, ModelType};
    use std::path::PathBuf;

    // Builds a ModelHandle with placeholder metadata; the path typically
    // does not exist, so no ONNX session is loaded.
    fn create_test_model(model_type: ModelType, path: PathBuf) -> Arc<ModelHandle> {
        let metadata = ModelMetadata {
            name: format!("{:?} Model", model_type),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };

        Arc::new(ModelHandle::new(model_type, path, metadata).unwrap())
    }

    // Engine construction must succeed even when no model file exists;
    // readiness is reported separately via `is_ready`.
    #[test]
    fn test_inference_engine_creation_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );

        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }

    // Detection on an engine without loaded sessions must fail with a
    // ModelLoading error rather than panicking.
    #[tokio::test]
    async fn test_detection_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();

        let png_data = create_test_png();
        let result = engine.run_detection(&png_data, 0.5).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }

    // Same contract as above, but for the recognition path.
    #[tokio::test]
    async fn test_recognition_fails_without_models() {
        let detection = create_test_model(
            ModelType::Detection,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let recognition = create_test_model(
            ModelType::Recognition,
            PathBuf::from("/nonexistent/model.onnx"),
        );
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();

        let png_data = create_test_png();
        let options = OcrOptions::default();
        let result = engine.run_recognition(&png_data, &options).await;

        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), OcrError::ModelLoading(_)));
    }

    #[test]
    fn test_is_ready_reflects_model_state() {
        let detection = create_test_model(ModelType::Detection, PathBuf::from("/fake/path"));
        let recognition = create_test_model(ModelType::Recognition, PathBuf::from("/fake/path"));
        let engine = InferenceEngine::new(detection, recognition, None, false).unwrap();
        assert!(!engine.is_ready());
    }

    // Encodes a tiny all-white 10x10 RGB image as PNG bytes for test input.
    fn create_test_png() -> Vec<u8> {
        use image::{ImageBuffer, RgbImage};
        let img: RgbImage = ImageBuffer::from_fn(10, 10, |_, _| image::Rgb([255, 255, 255]));
        let mut bytes: Vec<u8> = Vec::new();
        img.write_to(
            &mut std::io::Cursor::new(&mut bytes),
            image::ImageFormat::Png,
        )
        .unwrap();
        bytes
    }
}
|
||||
235
examples/scipix/src/ocr/mod.rs
Normal file
235
examples/scipix/src/ocr/mod.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
//! OCR Engine Module
|
||||
//!
|
||||
//! This module provides optical character recognition capabilities for the ruvector-scipix system.
|
||||
//! It supports text detection, character recognition, and mathematical expression recognition using
|
||||
//! ONNX models for high-performance inference.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The OCR module is organized into several submodules:
|
||||
//! - `engine`: Main OcrEngine for orchestrating OCR operations
|
||||
//! - `models`: Model management, loading, and caching
|
||||
//! - `inference`: ONNX inference operations for detection and recognition
|
||||
//! - `decoder`: Output decoding strategies (beam search, greedy, CTC)
|
||||
//! - `confidence`: Confidence scoring and calibration
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```no_run
|
||||
//! use ruvector_scipix::ocr::{OcrEngine, OcrOptions};
|
||||
//!
|
||||
//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
|
||||
//! // Initialize the OCR engine
|
||||
//! let engine = OcrEngine::new().await?;
|
||||
//!
|
||||
//! // Load an image
|
||||
//! let image_data = std::fs::read("math_formula.png")?;
|
||||
//!
|
||||
//! // Perform OCR
|
||||
//! let result = engine.recognize(&image_data).await?;
|
||||
//!
|
||||
//! println!("Recognized text: {}", result.text);
|
||||
//! println!("Confidence: {:.2}%", result.confidence * 100.0);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
// Submodules
|
||||
mod confidence;
|
||||
mod decoder;
|
||||
mod engine;
|
||||
mod inference;
|
||||
mod models;
|
||||
|
||||
// Public exports
|
||||
pub use confidence::{aggregate_confidence, calculate_confidence, ConfidenceCalibrator};
|
||||
pub use decoder::{BeamSearchDecoder, CTCDecoder, Decoder, GreedyDecoder, Vocabulary};
|
||||
pub use engine::{OcrEngine, OcrProcessor};
|
||||
pub use inference::{DetectionResult, InferenceEngine, RecognitionResult};
|
||||
pub use models::{ModelHandle, ModelRegistry};
|
||||
|
||||
/// OCR processing options
///
/// Tuning knobs consumed by the OCR engine; see [`Default`] for the
/// baseline configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
    /// Detection threshold for text regions (0.0-1.0)
    pub detection_threshold: f32,

    /// Recognition confidence threshold (0.0-1.0)
    pub recognition_threshold: f32,

    /// Enable mathematical expression recognition
    pub enable_math: bool,

    /// Decoder type to use
    pub decoder_type: DecoderType,

    /// Beam width for beam search decoder (ignored by other decoders)
    pub beam_width: usize,

    /// Maximum batch size for inference
    pub batch_size: usize,

    /// Enable GPU acceleration if available
    pub use_gpu: bool,

    /// Language hints for recognition (e.g. "en")
    pub languages: Vec<String>,
}
|
||||
|
||||
impl Default for OcrOptions {
    /// Baseline configuration: beam search (width 5), math recognition on,
    /// CPU inference, single-image batches, English language hint.
    fn default() -> Self {
        Self {
            detection_threshold: 0.5,
            recognition_threshold: 0.6,
            enable_math: true,
            decoder_type: DecoderType::BeamSearch,
            beam_width: 5,
            batch_size: 1,
            use_gpu: false,
            languages: vec!["en".to_string()],
        }
    }
}
|
||||
|
||||
/// Decoder type selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum DecoderType {
    /// Beam search decoder (higher quality, slower)
    BeamSearch,
    /// Greedy decoder (faster, lower quality)
    Greedy,
    /// CTC decoder for sequence-to-sequence models
    CTC,
}
|
||||
|
||||
/// OCR result containing recognized text and metadata
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Recognized text (all regions combined)
    pub text: String,

    /// Overall confidence score (0.0-1.0)
    pub confidence: f32,

    /// Detected text regions with their bounding boxes
    pub regions: Vec<TextRegion>,

    /// Whether mathematical expressions were detected
    pub has_math: bool,

    /// Processing time in milliseconds
    pub processing_time_ms: u64,
}
|
||||
|
||||
/// A detected text region with position and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
    /// Bounding box coordinates [x, y, width, height]
    pub bbox: [f32; 4],

    /// Recognized text in this region
    pub text: String,

    /// Confidence score for this region (0.0-1.0)
    pub confidence: f32,

    /// Region type (text, math, etc.)
    pub region_type: RegionType,

    /// Character-level details if available (may be empty)
    pub characters: Vec<Character>,
}
|
||||
|
||||
/// Type of text region
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
    /// Regular text
    Text,
    /// Mathematical expression
    Math,
    /// Diagram or figure
    Diagram,
    /// Table
    Table,
    /// Unknown type (classifier could not decide)
    Unknown,
}
|
||||
|
||||
/// Individual character with position and confidence
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Character {
    /// The character
    pub char: char,

    /// Confidence score (0.0-1.0)
    pub confidence: f32,

    /// Bounding box [x, y, width, height] if available
    pub bbox: Option<[f32; 4]>,
}
|
||||
|
||||
/// Error types for OCR operations
///
/// `Display` is derived via `thiserror`'s `#[error]` attributes;
/// `std::io::Error` converts automatically through `#[from]`.
#[derive(Debug, thiserror::Error)]
pub enum OcrError {
    /// A model file could not be found or its session could not be created.
    #[error("Model loading error: {0}")]
    ModelLoading(String),

    /// Inference-time failure (shape mismatch, missing model, ...).
    #[error("Inference error: {0}")]
    Inference(String),

    /// Image decoding/encoding/cropping failure.
    #[error("Image processing error: {0}")]
    ImageProcessing(String),

    /// Failure while decoding model output into text.
    #[error("Decoding error: {0}")]
    Decoding(String),

    /// Invalid engine or option configuration.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Underlying I/O error.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// Error reported by the ONNX Runtime bindings.
    #[error("ONNX Runtime error: {0}")]
    OnnxRuntime(String),
}
|
||||
|
||||
/// Convenience result alias used throughout the OCR module.
pub type Result<T> = std::result::Result<T, OcrError>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Default options must match the documented baseline configuration.
    #[test]
    fn test_ocr_options_default() {
        let opts = OcrOptions::default();
        assert_eq!(opts.beam_width, 5);
        assert_eq!(opts.decoder_type, DecoderType::BeamSearch);
        assert_eq!(opts.detection_threshold, 0.5);
        assert_eq!(opts.recognition_threshold, 0.6);
        assert!(opts.enable_math);
    }

    /// A plain text region round-trips its fields unchanged.
    #[test]
    fn test_text_region_creation() {
        let bbox = [10.0, 20.0, 100.0, 30.0];
        let region = TextRegion {
            bbox,
            text: String::from("Test"),
            confidence: 0.95,
            region_type: RegionType::Text,
            characters: Vec::new(),
        };
        assert_eq!(region.region_type, RegionType::Text);
        assert_eq!(region.text, "Test");
        assert_eq!(region.bbox[0], 10.0);
    }

    /// DecoderType derives Eq: variants equal themselves, differ otherwise.
    #[test]
    fn test_decoder_type_equality() {
        assert_ne!(DecoderType::Greedy, DecoderType::CTC);
        assert_ne!(DecoderType::BeamSearch, DecoderType::Greedy);
        assert_eq!(DecoderType::BeamSearch, DecoderType::BeamSearch);
    }
}
|
||||
373
examples/scipix/src/ocr/models.rs
Normal file
373
examples/scipix/src/ocr/models.rs
Normal file
@@ -0,0 +1,373 @@
|
||||
//! Model Management Module
|
||||
//!
|
||||
//! This module handles loading, caching, and managing ONNX models for OCR.
|
||||
//! It supports lazy loading, model downloading with progress tracking,
|
||||
//! and checksum verification.
|
||||
|
||||
use super::{OcrError, Result};
|
||||
use dashmap::DashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
#[cfg(feature = "ocr")]
|
||||
use ort::session::Session;
|
||||
#[cfg(feature = "ocr")]
|
||||
use parking_lot::Mutex;
|
||||
|
||||
/// Model types supported by the OCR engine
///
/// Used as the cache key in [`ModelRegistry`], hence `Eq + Hash`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Text detection model (finds text regions in images)
    Detection,
    /// Text recognition model (recognizes characters in regions)
    Recognition,
    /// Math expression recognition model
    Math,
}
|
||||
|
||||
/// Handle to a loaded ONNX model
///
/// Cloning is cheap: the underlying session (when present) is shared
/// through an `Arc`.
#[derive(Clone)]
pub struct ModelHandle {
    /// Model type
    model_type: ModelType,
    /// Path to the model file
    path: PathBuf,
    /// Model metadata
    metadata: ModelMetadata,
    /// ONNX Runtime session (when ocr feature is enabled)
    /// Wrapped in Mutex for mutable access required by ort 2.0 Session::run.
    /// `None` when the model file was missing or failed to load.
    #[cfg(feature = "ocr")]
    session: Option<Arc<Mutex<Session>>>,
    /// Mock session for when ocr feature is disabled
    #[cfg(not(feature = "ocr"))]
    #[allow(dead_code)]
    session: Option<()>,
}
|
||||
|
||||
impl ModelHandle {
    /// Create a new model handle
    ///
    /// With the `ocr` feature, attempts to build an ONNX Runtime session
    /// from `path`. Load failures (missing file, bad model) are logged and
    /// degrade to a handle with no session rather than an error — callers
    /// check readiness via [`ModelHandle::is_loaded`].
    pub fn new(model_type: ModelType, path: PathBuf, metadata: ModelMetadata) -> Result<Self> {
        debug!("Creating model handle for {:?} at {:?}", model_type, path);

        #[cfg(feature = "ocr")]
        let session = if path.exists() {
            match Session::builder() {
                Ok(builder) => match builder.commit_from_file(&path) {
                    Ok(session) => {
                        info!("Successfully loaded ONNX model: {:?}", path);
                        Some(Arc::new(Mutex::new(session)))
                    }
                    Err(e) => {
                        // Degrade gracefully: keep the handle, drop the session.
                        warn!("Failed to load ONNX model {:?}: {}", path, e);
                        None
                    }
                },
                Err(e) => {
                    warn!("Failed to create ONNX session builder: {}", e);
                    None
                }
            }
        } else {
            debug!("Model file not found: {:?}", path);
            None
        };

        #[cfg(not(feature = "ocr"))]
        let session: Option<()> = None;

        Ok(Self {
            model_type,
            path,
            metadata,
            session,
        })
    }

    /// Check if the model session is loaded
    pub fn is_loaded(&self) -> bool {
        self.session.is_some()
    }

    /// Get the ONNX session (only available with ocr feature)
    #[cfg(feature = "ocr")]
    pub fn session(&self) -> Option<&Arc<Mutex<Session>>> {
        self.session.as_ref()
    }

    /// Get the model type
    pub fn model_type(&self) -> ModelType {
        self.model_type
    }

    /// Get the model path
    pub fn path(&self) -> &Path {
        &self.path
    }

    /// Get model metadata
    pub fn metadata(&self) -> &ModelMetadata {
        &self.metadata
    }

    /// Get declared input shape for the model (from metadata, not the
    /// loaded session).
    pub fn input_shape(&self) -> &[usize] {
        &self.metadata.input_shape
    }

    /// Get declared output shape for the model (from metadata).
    pub fn output_shape(&self) -> &[usize] {
        &self.metadata.output_shape
    }
}
|
||||
|
||||
/// Model metadata
///
/// Declared (not introspected) properties of a model file; shapes here are
/// what the inference code trusts when building tensors.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model version
    pub version: String,
    /// Input tensor shape (NCHW)
    pub input_shape: Vec<usize>,
    /// Output tensor shape
    pub output_shape: Vec<usize>,
    /// Expected input data type (e.g. "float32")
    pub input_dtype: String,
    /// File size in bytes (approximate/expected)
    pub file_size: u64,
    /// SHA256 checksum, if known
    pub checksum: Option<String>,
}
|
||||
|
||||
/// Model registry for loading and caching models
pub struct ModelRegistry {
    /// Cache of loaded models, keyed by model type
    cache: DashMap<ModelType, Arc<ModelHandle>>,
    /// Base directory for model files
    model_dir: PathBuf,
    /// When true, a missing model file is a warning rather than an error
    lazy_loading: bool,
}
|
||||
|
||||
impl ModelRegistry {
    /// Create a new model registry rooted at `./models`.
    pub fn new() -> Self {
        Self::with_model_dir(PathBuf::from("./models"))
    }

    /// Create a new model registry with custom model directory.
    /// Lazy loading is enabled by default.
    pub fn with_model_dir(model_dir: PathBuf) -> Self {
        info!("Initializing model registry at {:?}", model_dir);
        Self {
            cache: DashMap::new(),
            model_dir,
            lazy_loading: true,
        }
    }

    /// Load the detection model
    pub async fn load_detection_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Detection).await
    }

    /// Load the recognition model
    pub async fn load_recognition_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Recognition).await
    }

    /// Load the math recognition model
    pub async fn load_math_model(&mut self) -> Result<Arc<ModelHandle>> {
        self.load_model(ModelType::Math).await
    }

    /// Load a model by type, consulting the cache first.
    ///
    /// When the model file is missing: with lazy loading this logs a
    /// warning and still creates a (session-less) handle; otherwise it
    /// returns `OcrError::ModelLoading`.
    ///
    /// NOTE(review): DashMap provides interior mutability, so `&mut self`
    /// appears unnecessary here — could be relaxed to `&self`; confirm no
    /// other invariant depends on exclusive access.
    pub async fn load_model(&mut self, model_type: ModelType) -> Result<Arc<ModelHandle>> {
        // Check cache first
        if let Some(handle) = self.cache.get(&model_type) {
            debug!("Model {:?} found in cache", model_type);
            return Ok(Arc::clone(handle.value()));
        }

        info!("Loading model {:?}...", model_type);

        // Get model path
        let model_path = self.get_model_path(model_type);

        // Check if model exists
        if !model_path.exists() {
            if self.lazy_loading {
                warn!(
                    "Model {:?} not found at {:?}. OCR will not work without models.",
                    model_type, model_path
                );
                warn!("Download models from: https://github.com/PaddlePaddle/PaddleOCR or configure custom models.");
            } else {
                return Err(OcrError::ModelLoading(format!(
                    "Model {:?} not found at {:?}",
                    model_type, model_path
                )));
            }
        }

        // Load model metadata
        let metadata = self.get_model_metadata(model_type);

        // Verify checksum if provided (currently a no-op placeholder).
        if let Some(ref checksum) = metadata.checksum {
            if model_path.exists() {
                debug!("Verifying model checksum: {}", checksum);
                // In production: verify_checksum(&model_path, checksum)?;
            }
        }

        // Create model handle (will load ONNX session if file exists)
        let handle = Arc::new(ModelHandle::new(model_type, model_path, metadata)?);

        // Cache the handle
        self.cache.insert(model_type, Arc::clone(&handle));

        if handle.is_loaded() {
            info!(
                "Model {:?} loaded successfully with ONNX session",
                model_type
            );
        } else {
            warn!(
                "Model {:?} handle created but ONNX session not loaded",
                model_type
            );
        }

        Ok(handle)
    }

    /// Get the file path for a model type
    fn get_model_path(&self, model_type: ModelType) -> PathBuf {
        let filename = match model_type {
            ModelType::Detection => "text_detection.onnx",
            ModelType::Recognition => "text_recognition.onnx",
            ModelType::Math => "math_recognition.onnx",
        };
        self.model_dir.join(filename)
    }

    /// Get default metadata for a model type
    ///
    /// Shapes/sizes are hard-coded defaults for the expected PaddleOCR-style
    /// models; custom models would need matching metadata.
    fn get_model_metadata(&self, model_type: ModelType) -> ModelMetadata {
        match model_type {
            ModelType::Detection => ModelMetadata {
                name: "Text Detection".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 3, 640, 640], // NCHW format
                output_shape: vec![1, 25200, 85],  // Detections
                input_dtype: "float32".to_string(),
                file_size: 50_000_000, // ~50MB
                checksum: None,
            },
            ModelType::Recognition => ModelMetadata {
                name: "Text Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 32, 128], // NCHW format
                output_shape: vec![1, 26, 37],    // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 20_000_000, // ~20MB
                checksum: None,
            },
            ModelType::Math => ModelMetadata {
                name: "Math Recognition".to_string(),
                version: "1.0.0".to_string(),
                input_shape: vec![1, 1, 64, 256], // NCHW format
                output_shape: vec![1, 50, 512],   // Sequence length, vocab size
                input_dtype: "float32".to_string(),
                file_size: 80_000_000, // ~80MB
                checksum: None,
            },
        }
    }

    /// Clear the model cache
    pub fn clear_cache(&mut self) {
        info!("Clearing model cache");
        self.cache.clear();
    }

    /// Get a cached model if available
    pub fn get_cached(&self, model_type: ModelType) -> Option<Arc<ModelHandle>> {
        self.cache.get(&model_type).map(|h| Arc::clone(h.value()))
    }

    /// Set lazy loading mode (see `load_model` for the effect)
    pub fn set_lazy_loading(&mut self, enabled: bool) {
        self.lazy_loading = enabled;
    }

    /// Get the model directory
    pub fn model_dir(&self) -> &Path {
        &self.model_dir
    }
}
|
||||
|
||||
impl Default for ModelRegistry {
    /// Equivalent to [`ModelRegistry::new`] (models under `./models`).
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_registry_creation() {
        let registry = ModelRegistry::new();
        assert_eq!(registry.model_dir(), Path::new("./models"));
        assert!(registry.lazy_loading);
    }

    #[test]
    fn test_model_path_generation() {
        let registry = ModelRegistry::new();
        let path = registry.get_model_path(ModelType::Detection);
        assert!(path.to_string_lossy().contains("text_detection.onnx"));
    }

    #[test]
    fn test_model_metadata() {
        let registry = ModelRegistry::new();
        let metadata = registry.get_model_metadata(ModelType::Recognition);
        assert_eq!(metadata.name, "Text Recognition");
        assert_eq!(metadata.version, "1.0.0");
        assert_eq!(metadata.input_shape, vec![1, 1, 32, 128]);
    }

    // Loading the same model twice must return the same cached Arc.
    #[tokio::test]
    async fn test_model_caching() {
        let mut registry = ModelRegistry::new();
        let model1 = registry.load_detection_model().await.unwrap();
        let model2 = registry.load_detection_model().await.unwrap();
        assert!(Arc::ptr_eq(&model1, &model2));
    }

    #[test]
    fn test_clear_cache() {
        let mut registry = ModelRegistry::new();
        registry.clear_cache();
        assert_eq!(registry.cache.len(), 0);
    }

    // A handle for a nonexistent file is created but reports not loaded.
    #[test]
    fn test_model_handle_without_file() {
        let path = PathBuf::from("/nonexistent/model.onnx");
        let metadata = ModelMetadata {
            name: "Test".to_string(),
            version: "1.0.0".to_string(),
            input_shape: vec![1, 3, 640, 640],
            output_shape: vec![1, 100, 85],
            input_dtype: "float32".to_string(),
            file_size: 1000,
            checksum: None,
        };
        let handle = ModelHandle::new(ModelType::Detection, path, metadata).unwrap();
        assert!(!handle.is_loaded());
    }
}
|
||||
396
examples/scipix/src/optimize/batch.rs
Normal file
396
examples/scipix/src/optimize/batch.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
//! Dynamic batching for throughput optimization
|
||||
//!
|
||||
//! Provides intelligent batching to maximize GPU/CPU utilization while
|
||||
//! maintaining acceptable latency.
|
||||
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::{oneshot, Mutex};
|
||||
use tokio::time::sleep;
|
||||
|
||||
/// Item in the batching queue
///
/// Pairs a unit of work with the oneshot channel used to deliver its
/// result back to the caller that enqueued it.
pub struct BatchItem<T, R> {
    /// Payload to be processed as part of a batch.
    pub data: T,
    /// Channel on which the processing outcome is sent back.
    pub response: oneshot::Sender<BatchResult<R>>,
    /// When the item entered the queue (for latency accounting).
    pub enqueued_at: Instant,
}
|
||||
|
||||
/// Result of batch processing for a single item.
pub type BatchResult<T> = std::result::Result<T, BatchError>;
|
||||
|
||||
/// Batch processing errors
#[derive(Debug, Clone)]
pub enum BatchError {
    /// The response channel closed before a result arrived.
    Timeout,
    /// The processor closure reported a failure for this item.
    ProcessingFailed(String),
    /// The queue already holds `max_queue_size` items.
    QueueFull,
}
|
||||
|
||||
impl std::fmt::Display for BatchError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
BatchError::Timeout => write!(f, "Batch processing timeout"),
|
||||
BatchError::ProcessingFailed(msg) => write!(f, "Processing failed: {}", msg),
|
||||
BatchError::QueueFull => write!(f, "Queue is full"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: Debug is derived and Display is implemented above, so the
// default `std::error::Error` methods suffice.
impl std::error::Error for BatchError {}
|
||||
|
||||
/// Dynamic batcher configuration
#[derive(Debug, Clone)]
pub struct BatchConfig {
    /// Maximum items in a batch
    pub max_batch_size: usize,
    /// Maximum time (ms) to wait before processing a partial batch
    pub max_wait_ms: u64,
    /// Maximum queue size; `add` fails with `QueueFull` beyond this
    pub max_queue_size: usize,
    /// Minimum batch size to prefer before the wait deadline
    pub preferred_batch_size: usize,
}
|
||||
|
||||
impl Default for BatchConfig {
    /// Defaults: batches of up to 32 (preferring 16), 50 ms max wait,
    /// queue bounded at 1000 items.
    fn default() -> Self {
        Self {
            max_batch_size: 32,
            max_wait_ms: 50,
            max_queue_size: 1000,
            preferred_batch_size: 16,
        }
    }
}
|
||||
|
||||
/// Dynamic batcher for throughput optimization
///
/// Collects individual requests into batches and hands them to a
/// user-supplied processor closure, trading a bounded amount of latency
/// for higher throughput.
pub struct DynamicBatcher<T, R> {
    /// Batching policy (sizes, wait time, queue bound).
    config: BatchConfig,
    /// Pending items awaiting batch assembly.
    queue: Arc<Mutex<VecDeque<BatchItem<T, R>>>>,
    /// Callback that processes one batch.
    /// NOTE(review): dispatch presumably relies on one result per input,
    /// in order — confirm in `process_batch`.
    processor: Arc<dyn Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync>,
    /// Cooperative shutdown flag checked by the run loop.
    shutdown: Arc<Mutex<bool>>,
}
|
||||
|
||||
impl<T, R> DynamicBatcher<T, R>
|
||||
where
|
||||
T: Send + 'static,
|
||||
R: Send + 'static,
|
||||
{
|
||||
/// Create new dynamic batcher
///
/// `processor` is invoked with each assembled batch and is expected to
/// return one `Result` per input item.
pub fn new<F>(config: BatchConfig, processor: F) -> Self
where
    F: Fn(Vec<T>) -> Vec<std::result::Result<R, String>> + Send + Sync + 'static,
{
    Self {
        config,
        queue: Arc::new(Mutex::new(VecDeque::new())),
        processor: Arc::new(processor),
        shutdown: Arc::new(Mutex::new(false)),
    }
}
|
||||
|
||||
/// Add item to batch queue
///
/// Enqueues `item` and waits until a batch containing it has been
/// processed, returning that item's individual result.
///
/// # Errors
/// - `BatchError::QueueFull` if the queue is at `max_queue_size`.
/// - `BatchError::Timeout` if the response sender is dropped before a
///   result arrives. NOTE(review): any dropped sender (e.g. item dropped
///   without a reply) is reported as `Timeout`, which may be misleading —
///   consider a dedicated variant.
pub async fn add(&self, item: T) -> BatchResult<R> {
    let (tx, rx) = oneshot::channel();

    let batch_item = BatchItem {
        data: item,
        response: tx,
        enqueued_at: Instant::now(),
    };

    // Scope the lock so it is released before awaiting the response.
    {
        let mut queue = self.queue.lock().await;
        if queue.len() >= self.config.max_queue_size {
            return Err(BatchError::QueueFull);
        }
        queue.push_back(batch_item);
    }

    // Wait for response
    rx.await.map_err(|_| BatchError::Timeout)?
}
|
||||
|
||||
/// Start batch processing loop
|
||||
pub async fn run(&self) {
|
||||
let mut last_process = Instant::now();
|
||||
|
||||
loop {
|
||||
// Check if shutdown requested
|
||||
{
|
||||
let shutdown = self.shutdown.lock().await;
|
||||
if *shutdown {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let should_process = {
|
||||
let queue = self.queue.lock().await;
|
||||
queue.len() >= self.config.max_batch_size
|
||||
|| (queue.len() >= self.config.preferred_batch_size
|
||||
&& last_process.elapsed().as_millis() >= self.config.max_wait_ms as u128)
|
||||
|| (queue.len() > 0
|
||||
&& last_process.elapsed().as_millis() >= self.config.max_wait_ms as u128)
|
||||
};
|
||||
|
||||
if should_process {
|
||||
self.process_batch().await;
|
||||
last_process = Instant::now();
|
||||
} else {
|
||||
// Sleep briefly to avoid busy waiting
|
||||
sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
}
|
||||
|
||||
// Process remaining items before shutdown
|
||||
self.process_batch().await;
|
||||
}
|
||||
|
||||
/// Process current batch
|
||||
async fn process_batch(&self) {
|
||||
let items = {
|
||||
let mut queue = self.queue.lock().await;
|
||||
let batch_size = self.config.max_batch_size.min(queue.len());
|
||||
if batch_size == 0 {
|
||||
return;
|
||||
}
|
||||
queue.drain(..batch_size).collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if items.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Extract data and response channels
|
||||
let (data, responses): (Vec<_>, Vec<_>) = items
|
||||
.into_iter()
|
||||
.map(|item| (item.data, item.response))
|
||||
.unzip();
|
||||
|
||||
// Process batch
|
||||
let results = (self.processor)(data);
|
||||
|
||||
// Send responses
|
||||
for (response_tx, result) in responses.into_iter().zip(results.into_iter()) {
|
||||
let batch_result = result.map_err(|e| BatchError::ProcessingFailed(e));
|
||||
let _ = response_tx.send(batch_result);
|
||||
}
|
||||
}
|
||||
|
||||
/// Gracefully shutdown the batcher
|
||||
pub async fn shutdown(&self) {
|
||||
let mut shutdown = self.shutdown.lock().await;
|
||||
*shutdown = true;
|
||||
}
|
||||
|
||||
/// Get current queue size
|
||||
pub async fn queue_size(&self) -> usize {
|
||||
self.queue.lock().await.len()
|
||||
}
|
||||
|
||||
/// Get current queue statistics
|
||||
pub async fn stats(&self) -> BatchStats {
|
||||
let queue = self.queue.lock().await;
|
||||
let queue_size = queue.len();
|
||||
|
||||
let max_wait = queue
|
||||
.front()
|
||||
.map(|item| item.enqueued_at.elapsed())
|
||||
.unwrap_or(Duration::from_secs(0));
|
||||
|
||||
BatchStats {
|
||||
queue_size,
|
||||
max_wait_time: max_wait,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch statistics
///
/// Snapshot of queue state produced by [`DynamicBatcher::stats`].
#[derive(Debug, Clone)]
pub struct BatchStats {
    /// Number of items currently waiting in the queue.
    pub queue_size: usize,
    /// Age of the oldest queued item; zero when the queue is empty.
    pub max_wait_time: Duration,
}
|
||||
|
||||
/// Adaptive batcher that adjusts batch size based on latency
///
/// Wraps a [`DynamicBatcher`] and tunes `max_batch_size` from observed
/// end-to-end request latencies.
///
/// NOTE(review): the adapted values are written to `self.config`, which is
/// a separate clone of the configuration — `inner` keeps the batch size it
/// was constructed with, so the adaptation currently has no effect on the
/// actual batching behavior. Confirm whether the config should be shared
/// with `DynamicBatcher`.
pub struct AdaptiveBatcher<T, R> {
    // Underlying batcher that does the actual queueing/processing.
    inner: DynamicBatcher<T, R>,
    // Adapted copy of the configuration (see NOTE above).
    config: Arc<Mutex<BatchConfig>>,
    // Sliding window of the last 100 observed request latencies.
    latency_history: Arc<Mutex<VecDeque<Duration>>>,
    // Average latency the adapter steers toward.
    target_latency: Duration,
}

impl<T, R> AdaptiveBatcher<T, R>
where
    T: Send + 'static,
    R: Send + 'static,
{
    /// Create adaptive batcher with target latency
    pub fn new<F>(initial_config: BatchConfig, target_latency: Duration, processor: F) -> Self
    where
        F: Fn(Vec<T>) -> Vec<Result<R, String>> + Send + Sync + 'static,
    {
        let config = Arc::new(Mutex::new(initial_config.clone()));
        let inner = DynamicBatcher::new(initial_config, processor);

        Self {
            inner,
            config,
            latency_history: Arc::new(Mutex::new(VecDeque::with_capacity(100))),
            target_latency,
        }
    }

    /// Add item and adapt batch size
    ///
    /// Forwards to the inner batcher, records the observed latency, and
    /// every 10 requests nudges `max_batch_size` up or down by ~10%
    /// (clamped to [1, 128]).
    pub async fn add(&self, item: T) -> Result<R, BatchError> {
        let start = Instant::now();
        let result = self.inner.add(item).await;
        let latency = start.elapsed();

        // Record latency in a bounded (100-entry) sliding window.
        {
            let mut history = self.latency_history.lock().await;
            history.push_back(latency);
            if history.len() > 100 {
                history.pop_front();
            }
        }

        // Adapt batch size every 10 requests
        {
            let history = self.latency_history.lock().await;
            if history.len() % 10 == 0 && history.len() >= 10 {
                let avg_latency: Duration = history.iter().sum::<Duration>() / history.len() as u32;

                let mut config = self.config.lock().await;
                if avg_latency > self.target_latency {
                    // Reduce batch size to lower latency
                    config.max_batch_size = (config.max_batch_size * 9 / 10).max(1);
                } else if avg_latency < self.target_latency / 2 {
                    // Increase batch size for better throughput
                    config.max_batch_size = (config.max_batch_size * 11 / 10).min(128);
                }
            }
        }

        result
    }

    /// Run the batcher
    ///
    /// Delegates to the inner batcher's processing loop; returns only after
    /// shutdown.
    pub async fn run(&self) {
        self.inner.run().await;
    }

    /// Get current configuration
    ///
    /// Returns the adapted copy (see the struct-level NOTE about it not
    /// being shared with the inner batcher).
    pub async fn current_config(&self) -> BatchConfig {
        self.config.lock().await.clone()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end: items added concurrently are batched and doubled.
    #[tokio::test]
    async fn test_dynamic_batcher() {
        let config = BatchConfig {
            max_batch_size: 4,
            max_wait_ms: 100,
            max_queue_size: 100,
            preferred_batch_size: 2,
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(|x| Ok(x * 2)).collect()
        }));

        // Start processing loop
        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        // Add items concurrently.
        let mut handles = vec![];
        for i in 0..8 {
            let batcher = batcher.clone();
            handles.push(tokio::spawn(async move { batcher.add(i).await }));
        }

        // Wait for results
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap().unwrap();
            assert_eq!(result, (i as i32) * 2);
        }

        batcher.shutdown().await;
    }

    /// `stats()` reports queued items while no run loop is draining them.
    #[tokio::test]
    async fn test_batch_stats() {
        let config = BatchConfig::default();
        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));

        // `add()` is async: the futures must actually be polled to enqueue.
        // The previous version dropped them unawaited (`let _ = add(..);`),
        // which left the queue empty and the assertion below failing.
        for i in 1..=3 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        // Let the spawned tasks reach their enqueue point.
        tokio::time::sleep(Duration::from_millis(50)).await;

        let stats = batcher.stats().await;
        assert_eq!(stats.queue_size, 3);
    }

    /// Adding beyond `max_queue_size` fails with `QueueFull`.
    #[tokio::test]
    async fn test_queue_full() {
        let config = BatchConfig {
            max_queue_size: 2,
            ..Default::default()
        };

        let batcher = Arc::new(DynamicBatcher::new(config, |items: Vec<i32>| {
            items.into_iter().map(Ok).collect()
        }));

        // Fill the queue. No run loop is started, so nothing drains it;
        // the adds must be spawned (not dropped) to take effect.
        for i in 1..=2 {
            let batcher = batcher.clone();
            tokio::spawn(async move {
                let _ = batcher.add(i).await;
            });
        }
        tokio::time::sleep(Duration::from_millis(50)).await;

        // This should fail - queue is full
        let result = batcher.add(3).await;
        assert!(matches!(result, Err(BatchError::QueueFull)));
    }

    /// The adaptive wrapper processes requests and keeps a sane config.
    #[tokio::test]
    async fn test_adaptive_batcher() {
        let config = BatchConfig {
            max_batch_size: 8,
            max_wait_ms: 50,
            max_queue_size: 100,
            preferred_batch_size: 4,
        };

        let batcher = Arc::new(AdaptiveBatcher::new(
            config,
            Duration::from_millis(100),
            |items: Vec<i32>| items.into_iter().map(|x| Ok(x * 2)).collect(),
        ));

        let batcher_clone = batcher.clone();
        tokio::spawn(async move {
            batcher_clone.run().await;
        });

        // Process some requests
        for i in 0..20 {
            let result = batcher.add(i).await.unwrap();
            assert_eq!(result, i * 2);
        }

        // Configuration should have adapted
        let final_config = batcher.current_config().await;
        assert!(final_config.max_batch_size > 0);
    }
}
|
||||
409
examples/scipix/src/optimize/memory.rs
Normal file
409
examples/scipix/src/optimize/memory.rs
Normal file
@@ -0,0 +1,409 @@
|
||||
//! Memory optimization utilities
|
||||
//!
|
||||
//! Provides object pooling, memory-mapped file loading, and zero-copy operations.
|
||||
|
||||
use memmap2::{Mmap, MmapOptions};
|
||||
use std::collections::VecDeque;
|
||||
use std::fs::File;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
use super::memory_opt_enabled;
|
||||
use crate::error::{Result, ScipixError};
|
||||
|
||||
/// Object pool for reusable buffers
|
||||
pub struct BufferPool<T> {
|
||||
pool: Arc<Mutex<VecDeque<T>>>,
|
||||
factory: Arc<dyn Fn() -> T + Send + Sync>,
|
||||
#[allow(dead_code)]
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
impl<T: Send + 'static> BufferPool<T> {
|
||||
/// Create a new buffer pool
|
||||
pub fn new<F>(factory: F, initial_size: usize, max_size: usize) -> Self
|
||||
where
|
||||
F: Fn() -> T + Send + Sync + 'static,
|
||||
{
|
||||
let factory = Arc::new(factory);
|
||||
let pool = Arc::new(Mutex::new(VecDeque::with_capacity(max_size)));
|
||||
|
||||
// Pre-allocate initial buffers
|
||||
if memory_opt_enabled() {
|
||||
let mut pool_lock = pool.lock().unwrap();
|
||||
for _ in 0..initial_size {
|
||||
pool_lock.push_back(factory());
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
pool,
|
||||
factory,
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Acquire a buffer from the pool
|
||||
pub fn acquire(&self) -> PooledBuffer<T> {
|
||||
let buffer = if memory_opt_enabled() {
|
||||
self.pool
|
||||
.lock()
|
||||
.unwrap()
|
||||
.pop_front()
|
||||
.unwrap_or_else(|| (self.factory)())
|
||||
} else {
|
||||
(self.factory)()
|
||||
};
|
||||
|
||||
PooledBuffer {
|
||||
buffer: Some(buffer),
|
||||
pool: self.pool.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current pool size
|
||||
pub fn size(&self) -> usize {
|
||||
self.pool.lock().unwrap().len()
|
||||
}
|
||||
|
||||
/// Clear the pool
|
||||
pub fn clear(&self) {
|
||||
self.pool.lock().unwrap().clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// RAII guard for pooled buffers
///
/// Gives access to a buffer borrowed from a [`BufferPool`] and returns it
/// to the pool on drop (when memory optimization is enabled).
pub struct PooledBuffer<T> {
    // `Some` until drop; `Option` lets `drop` move the buffer out.
    buffer: Option<T>,
    // Shared handle back to the owning pool's queue.
    pool: Arc<Mutex<VecDeque<T>>>,
}

impl<T> PooledBuffer<T> {
    /// Get mutable reference to buffer
    pub fn get_mut(&mut self) -> &mut T {
        // Invariant: `buffer` is only `None` during drop.
        self.buffer.as_mut().unwrap()
    }

    /// Get immutable reference to buffer
    pub fn get(&self) -> &T {
        self.buffer.as_ref().unwrap()
    }
}

impl<T> Drop for PooledBuffer<T> {
    // Return the buffer to the pool when pooling is active.
    //
    // NOTE(review): the push-back is unconditional, so the pool can grow
    // past the owning pool's `max_size`; buffers are also returned without
    // being cleared, so the next user sees stale contents. Confirm both
    // behaviors are intended.
    fn drop(&mut self) {
        if memory_opt_enabled() {
            if let Some(buffer) = self.buffer.take() {
                let mut pool = self.pool.lock().unwrap();
                pool.push_back(buffer);
            }
        }
    }
}

impl<T> std::ops::Deref for PooledBuffer<T> {
    type Target = T;

    // Transparent read access to the pooled buffer.
    fn deref(&self) -> &Self::Target {
        self.buffer.as_ref().unwrap()
    }
}

impl<T> std::ops::DerefMut for PooledBuffer<T> {
    // Transparent write access to the pooled buffer.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.buffer.as_mut().unwrap()
    }
}
|
||||
|
||||
/// Memory-mapped model file
|
||||
pub struct MmapModel {
|
||||
_mmap: Mmap,
|
||||
data: *const u8,
|
||||
len: usize,
|
||||
}
|
||||
|
||||
unsafe impl Send for MmapModel {}
|
||||
unsafe impl Sync for MmapModel {}
|
||||
|
||||
impl MmapModel {
|
||||
/// Load model from file using memory mapping
|
||||
pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
|
||||
let file = File::open(path.as_ref()).map_err(|e| ScipixError::Io(e))?;
|
||||
|
||||
let mmap = unsafe {
|
||||
MmapOptions::new()
|
||||
.map(&file)
|
||||
.map_err(|e| ScipixError::Io(e))?
|
||||
};
|
||||
|
||||
let data = mmap.as_ptr();
|
||||
let len = mmap.len();
|
||||
|
||||
Ok(Self {
|
||||
_mmap: mmap,
|
||||
data,
|
||||
len,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get slice of model data
|
||||
pub fn as_slice(&self) -> &[u8] {
|
||||
unsafe { std::slice::from_raw_parts(self.data, self.len) }
|
||||
}
|
||||
|
||||
/// Get size of mapped region
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Check if empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
}
|
||||
|
||||
/// Zero-copy image view
|
||||
pub struct ImageView<'a> {
|
||||
data: &'a [u8],
|
||||
width: u32,
|
||||
height: u32,
|
||||
channels: u8,
|
||||
}
|
||||
|
||||
impl<'a> ImageView<'a> {
|
||||
/// Create new image view from raw data
|
||||
pub fn new(data: &'a [u8], width: u32, height: u32, channels: u8) -> Result<Self> {
|
||||
let expected_len = (width * height * channels as u32) as usize;
|
||||
if data.len() != expected_len {
|
||||
return Err(ScipixError::InvalidInput(format!(
|
||||
"Invalid data length: expected {}, got {}",
|
||||
expected_len,
|
||||
data.len()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
data,
|
||||
width,
|
||||
height,
|
||||
channels,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get pixel at (x, y)
|
||||
pub fn pixel(&self, x: u32, y: u32) -> &[u8] {
|
||||
let offset = ((y * self.width + x) * self.channels as u32) as usize;
|
||||
&self.data[offset..offset + self.channels as usize]
|
||||
}
|
||||
|
||||
/// Get raw data slice
|
||||
pub fn data(&self) -> &[u8] {
|
||||
self.data
|
||||
}
|
||||
|
||||
/// Get dimensions
|
||||
pub fn dimensions(&self) -> (u32, u32) {
|
||||
(self.width, self.height)
|
||||
}
|
||||
|
||||
/// Get number of channels
|
||||
pub fn channels(&self) -> u8 {
|
||||
self.channels
|
||||
}
|
||||
|
||||
/// Create subview (region of interest)
|
||||
pub fn subview(&self, x: u32, y: u32, width: u32, height: u32) -> Result<Self> {
|
||||
if x + width > self.width || y + height > self.height {
|
||||
return Err(ScipixError::InvalidInput(
|
||||
"Subview out of bounds".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// For simplicity, this creates a copy. True zero-copy would need stride support
|
||||
let mut subview_data = Vec::new();
|
||||
for row in y..y + height {
|
||||
let start = ((row * self.width + x) * self.channels as u32) as usize;
|
||||
let end = start + (width * self.channels as u32) as usize;
|
||||
subview_data.extend_from_slice(&self.data[start..end]);
|
||||
}
|
||||
|
||||
// This temporarily leaks memory - in production, use arena allocator
|
||||
let leaked = Box::leak(subview_data.into_boxed_slice());
|
||||
|
||||
Ok(Self {
|
||||
data: leaked,
|
||||
width,
|
||||
height,
|
||||
channels: self.channels,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Arena allocator for temporary allocations
///
/// Bump allocator over a single growable byte buffer. `reset` rewinds the
/// bump pointer without releasing capacity, so a warmed-up arena allocates
/// without touching the system allocator.
pub struct Arena {
    /// Backing storage; `buffer[..offset]` is handed out.
    buffer: Vec<u8>,
    /// Bump pointer: index of the next free byte.
    offset: usize,
}

impl Arena {
    /// Create new arena with capacity
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            buffer: Vec::with_capacity(capacity),
            offset: 0,
        }
    }

    /// Allocate `size` bytes whose start offset is a multiple of `align`.
    ///
    /// The returned slice is zero-initialized.
    ///
    /// NOTE(review): alignment is relative to the start of the backing
    /// `Vec<u8>`, which is only guaranteed 1-aligned — do not rely on the
    /// pointer alignment of the result.
    ///
    /// # Panics
    /// Panics if `align` is zero.
    pub fn alloc(&mut self, size: usize, align: usize) -> &mut [u8] {
        // Round the bump pointer up to the requested alignment.
        let padding = (align - (self.offset % align)) % align;
        self.offset += padding;

        let start = self.offset;
        let end = start + size;

        // Grow with *initialized* bytes. The previous implementation used
        // `reserve` + `set_len`, which exposed uninitialized memory through
        // the returned slice (undefined behavior); `resize` fixes that and
        // removes the `unsafe` block entirely.
        if end > self.buffer.len() {
            self.buffer.resize(end, 0);
        }

        self.offset = end;
        &mut self.buffer[start..end]
    }

    /// Reset arena (keeps capacity)
    pub fn reset(&mut self) {
        self.offset = 0;
        self.buffer.clear();
    }

    /// Bytes currently allocated, including alignment padding.
    pub fn usage(&self) -> usize {
        self.offset
    }

    /// Total capacity of the backing buffer.
    pub fn capacity(&self) -> usize {
        self.buffer.capacity()
    }
}
|
||||
|
||||
/// Global buffer pools for common sizes
|
||||
pub struct GlobalPools {
|
||||
small: BufferPool<Vec<u8>>, // 1KB buffers
|
||||
medium: BufferPool<Vec<u8>>, // 64KB buffers
|
||||
large: BufferPool<Vec<u8>>, // 1MB buffers
|
||||
}
|
||||
|
||||
impl GlobalPools {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
small: BufferPool::new(|| Vec::with_capacity(1024), 10, 100),
|
||||
medium: BufferPool::new(|| Vec::with_capacity(64 * 1024), 5, 50),
|
||||
large: BufferPool::new(|| Vec::with_capacity(1024 * 1024), 2, 20),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the global pools instance
|
||||
pub fn get() -> &'static Self {
|
||||
static POOLS: std::sync::OnceLock<GlobalPools> = std::sync::OnceLock::new();
|
||||
POOLS.get_or_init(GlobalPools::new)
|
||||
}
|
||||
|
||||
/// Acquire small buffer (1KB)
|
||||
pub fn acquire_small(&self) -> PooledBuffer<Vec<u8>> {
|
||||
self.small.acquire()
|
||||
}
|
||||
|
||||
/// Acquire medium buffer (64KB)
|
||||
pub fn acquire_medium(&self) -> PooledBuffer<Vec<u8>> {
|
||||
self.medium.acquire()
|
||||
}
|
||||
|
||||
/// Acquire large buffer (1MB)
|
||||
pub fn acquire_large(&self) -> PooledBuffer<Vec<u8>> {
|
||||
self.large.acquire()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Pool pre-allocation and recycling round-trip.
    ///
    /// Assumes the global opt level is the default (`Full`) so that
    /// `memory_opt_enabled()` is true.
    #[test]
    fn test_buffer_pool() {
        let pool = BufferPool::new(|| Vec::with_capacity(1024), 2, 10);

        // Two buffers were pre-allocated.
        assert_eq!(pool.size(), 2);

        let mut buf1 = pool.acquire();
        assert_eq!(buf1.capacity(), 1024);
        buf1.extend_from_slice(b"test");

        // Acquire popped one buffer (2 -> 1); dropping returns it (1 -> 2).
        // The previous expectation of 3 was wrong: nothing ever adds a
        // third buffer to this pool.
        drop(buf1);
        assert_eq!(pool.size(), 2); // Returned to pool
    }

    /// A mapped file exposes its exact byte contents.
    #[test]
    fn test_mmap_model() {
        let mut temp = NamedTempFile::new().unwrap();
        temp.write_all(b"test model data").unwrap();
        temp.flush().unwrap();

        let mmap = MmapModel::from_file(temp.path()).unwrap();
        assert_eq!(mmap.as_slice(), b"test model data");
        assert_eq!(mmap.len(), 15);
    }

    /// Pixel addressing on a 2x2 RGBA image.
    #[test]
    fn test_image_view() {
        let data = vec![
            255, 0, 0, 255, // Red pixel
            0, 255, 0, 255, // Green pixel
            0, 0, 255, 255, // Blue pixel
            255, 255, 255, 255, // White pixel
        ];

        let view = ImageView::new(&data, 2, 2, 4).unwrap();
        assert_eq!(view.dimensions(), (2, 2));
        assert_eq!(view.pixel(0, 0), &[255, 0, 0, 255]);
        assert_eq!(view.pixel(1, 1), &[255, 255, 255, 255]);
    }

    /// Bump allocation honors sizes; reset rewinds usage to zero.
    #[test]
    fn test_arena() {
        let mut arena = Arena::with_capacity(1024);

        let slice1 = arena.alloc(100, 8);
        assert_eq!(slice1.len(), 100);

        let slice2 = arena.alloc(200, 8);
        assert_eq!(slice2.len(), 200);

        assert!(arena.usage() >= 300);

        arena.reset();
        assert_eq!(arena.usage(), 0);
    }

    /// The lazily-initialized global pools hand out tiered buffer sizes.
    #[test]
    fn test_global_pools() {
        let pools = GlobalPools::get();

        let small = pools.acquire_small();
        assert!(small.capacity() >= 1024);

        let medium = pools.acquire_medium();
        assert!(medium.capacity() >= 64 * 1024);

        let large = pools.acquire_large();
        assert!(large.capacity() >= 1024 * 1024);
    }
}
|
||||
169
examples/scipix/src/optimize/mod.rs
Normal file
169
examples/scipix/src/optimize/mod.rs
Normal file
@@ -0,0 +1,169 @@
|
||||
//! Performance optimization utilities for scipix OCR
|
||||
//!
|
||||
//! This module provides runtime feature detection and optimized code paths
|
||||
//! for different CPU architectures and capabilities.
|
||||
|
||||
pub mod batch;
|
||||
pub mod memory;
|
||||
pub mod parallel;
|
||||
pub mod quantize;
|
||||
pub mod simd;
|
||||
|
||||
use std::sync::OnceLock;
|
||||
|
||||
/// CPU features detected at runtime
///
/// SIMD capability flags; features belonging to a different architecture
/// than the current target are always reported as `false`.
#[derive(Debug, Clone, Copy)]
pub struct CpuFeatures {
    /// x86-64 AVX2.
    pub avx2: bool,
    /// x86-64 AVX-512 Foundation.
    pub avx512f: bool,
    /// AArch64 NEON.
    pub neon: bool,
    /// x86-64 SSE4.2.
    pub sse4_2: bool,
}
|
||||
|
||||
// Cached detection result; feature probing runs once per process.
static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();

/// Detect CPU features at runtime
///
/// Detection runs once and is cached in a `OnceLock`; subsequent calls
/// return the cached copy. Exactly one of the `cfg` branches below is
/// compiled in for any given target.
pub fn detect_features() -> CpuFeatures {
    *CPU_FEATURES.get_or_init(|| {
        #[cfg(target_arch = "x86_64")]
        {
            CpuFeatures {
                avx2: is_x86_feature_detected!("avx2"),
                avx512f: is_x86_feature_detected!("avx512f"),
                neon: false,
                sse4_2: is_x86_feature_detected!("sse4.2"),
            }
        }
        #[cfg(target_arch = "aarch64")]
        {
            CpuFeatures {
                avx2: false,
                avx512f: false,
                neon: std::arch::is_aarch64_feature_detected!("neon"),
                sse4_2: false,
            }
        }
        // Conservative default: no SIMD on any other architecture.
        #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
        {
            CpuFeatures {
                avx2: false,
                avx512f: false,
                neon: false,
                sse4_2: false,
            }
        }
    })
}
|
||||
|
||||
/// Get the detected CPU features
///
/// Thin alias for [`detect_features`]; the result is cached after the
/// first call.
pub fn get_features() -> CpuFeatures {
    detect_features()
}
|
||||
|
||||
/// Runtime dispatch to optimized implementation
///
/// Implementors provide a SIMD path and a scalar fallback; `execute_auto`
/// chooses between them from the detected CPU features.
pub trait OptimizedOp<T> {
    /// Execute the operation with the best available implementation
    fn execute(&self, input: T) -> T;

    /// Execute with SIMD if available, fallback to scalar
    ///
    /// Note: this consults only hardware capability (`get_features`), not
    /// the configured [`OptLevel`]; callers that want to honor
    /// `simd_enabled()` must check it themselves.
    fn execute_auto(&self, input: T) -> T {
        let features = get_features();
        if features.avx2 || features.avx512f || features.neon {
            self.execute_simd(input)
        } else {
            self.execute_scalar(input)
        }
    }

    /// SIMD implementation
    fn execute_simd(&self, input: T) -> T;

    /// Scalar fallback implementation
    fn execute_scalar(&self, input: T) -> T;
}
|
||||
|
||||
/// Optimization level configuration
///
/// Ordered roughly by aggressiveness. The default is [`OptLevel::Full`],
/// expressed with `#[derive(Default)]` + `#[default]` (stable since Rust
/// 1.62; this file already requires 1.70 for `OnceLock`) instead of the
/// previous hand-written `impl Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum OptLevel {
    /// No optimizations, scalar code only
    None,
    /// Use SIMD when available
    Simd,
    /// Use SIMD + parallel processing
    Parallel,
    /// All optimizations including memory optimizations
    #[default]
    Full,
}
|
||||
|
||||
/// Global optimization configuration
///
/// Write-once: whichever of `set_opt_level` or `get_opt_level` (which
/// initializes to the default) runs first wins.
static OPT_LEVEL: OnceLock<OptLevel> = OnceLock::new();

/// Set the optimization level
///
/// Silently ignored if the level was already initialized — including by a
/// prior `get_opt_level()` call — so call this before any query.
pub fn set_opt_level(level: OptLevel) {
    OPT_LEVEL.set(level).ok();
}
|
||||
|
||||
/// Get the current optimization level
///
/// Initializes the global level to the default (`OptLevel::Full`) on first
/// use if `set_opt_level` was never called.
pub fn get_opt_level() -> OptLevel {
    *OPT_LEVEL.get_or_init(OptLevel::default)
}
|
||||
|
||||
/// Check if SIMD optimizations are enabled
|
||||
pub fn simd_enabled() -> bool {
|
||||
matches!(
|
||||
get_opt_level(),
|
||||
OptLevel::Simd | OptLevel::Parallel | OptLevel::Full
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if parallel optimizations are enabled
|
||||
pub fn parallel_enabled() -> bool {
|
||||
matches!(get_opt_level(), OptLevel::Parallel | OptLevel::Full)
|
||||
}
|
||||
|
||||
/// Check if memory optimizations are enabled
|
||||
pub fn memory_opt_enabled() -> bool {
|
||||
matches!(get_opt_level(), OptLevel::Full)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Sanity-checks feature detection.
    ///
    /// The previous assertion (`a || b || c || d || (!a && !b && !c && !d)`)
    /// was a tautology that could never fail; replaced with checks that can.
    #[test]
    fn test_feature_detection() {
        let features = detect_features();
        println!("Detected features: {:?}", features);

        // x86 and ARM feature families are mutually exclusive: NEON is
        // never reported together with any x86 feature.
        assert!(!(features.neon && (features.avx2 || features.avx512f || features.sse4_2)));

        // Detection is cached in a OnceLock, so repeated calls agree.
        let again = detect_features();
        assert_eq!(features.avx2, again.avx2);
        assert_eq!(features.avx512f, again.avx512f);
        assert_eq!(features.neon, again.neon);
        assert_eq!(features.sse4_2, again.sse4_2);
    }

    /// The global level is write-once.
    #[test]
    fn test_opt_level() {
        // First read initializes the level to the default (Full).
        assert_eq!(get_opt_level(), OptLevel::Full);

        set_opt_level(OptLevel::Simd);
        // Can't change after first init, should still be Full.
        assert_eq!(get_opt_level(), OptLevel::Full);
    }

    /// At the default (Full) level every optimization class is on.
    #[test]
    fn test_optimization_checks() {
        assert!(simd_enabled());
        assert!(parallel_enabled());
        assert!(memory_opt_enabled());
    }
}
|
||||
335
examples/scipix/src/optimize/parallel.rs
Normal file
335
examples/scipix/src/optimize/parallel.rs
Normal file
@@ -0,0 +1,335 @@
|
||||
//! Parallel processing utilities for OCR pipeline
|
||||
//!
|
||||
//! Provides parallel image preprocessing, batch OCR, and pipelined execution.
|
||||
|
||||
use image::DynamicImage;
|
||||
use rayon::prelude::*;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Semaphore;
|
||||
|
||||
use super::parallel_enabled;
|
||||
|
||||
/// Parallel preprocessing of multiple images
|
||||
pub fn parallel_preprocess<F>(images: Vec<DynamicImage>, preprocess_fn: F) -> Vec<DynamicImage>
|
||||
where
|
||||
F: Fn(DynamicImage) -> DynamicImage + Sync + Send,
|
||||
{
|
||||
if !parallel_enabled() {
|
||||
return images.into_iter().map(preprocess_fn).collect();
|
||||
}
|
||||
|
||||
images.into_par_iter().map(preprocess_fn).collect()
|
||||
}
|
||||
|
||||
/// Parallel processing with error handling
|
||||
pub fn parallel_preprocess_result<F, E>(
|
||||
images: Vec<DynamicImage>,
|
||||
preprocess_fn: F,
|
||||
) -> Vec<std::result::Result<DynamicImage, E>>
|
||||
where
|
||||
F: Fn(DynamicImage) -> std::result::Result<DynamicImage, E> + Sync + Send,
|
||||
E: Send,
|
||||
{
|
||||
if !parallel_enabled() {
|
||||
return images.into_iter().map(preprocess_fn).collect();
|
||||
}
|
||||
|
||||
images.into_par_iter().map(preprocess_fn).collect()
|
||||
}
|
||||
|
||||
/// Pipeline parallel execution for OCR workflow
///
/// Executes stages in a pipeline: preprocess | detect | recognize
/// Each stage can start processing the next item while previous stages
/// continue with subsequent items.
///
/// NOTE(review): `execute_batch` actually runs both stages back-to-back
/// per item (data parallelism across items), not stage-overlapped
/// pipelining as described above — confirm which behavior is intended.
pub struct PipelineExecutor<T, U, V> {
    // First transformation stage.
    stage1: Arc<dyn Fn(T) -> U + Send + Sync>,
    // Second transformation stage, consuming stage1's output.
    stage2: Arc<dyn Fn(U) -> V + Send + Sync>,
}
|
||||
|
||||
impl<T, U, V> PipelineExecutor<T, U, V>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
V: Send,
|
||||
{
|
||||
pub fn new<F1, F2>(stage1: F1, stage2: F2) -> Self
|
||||
where
|
||||
F1: Fn(T) -> U + Send + Sync + 'static,
|
||||
F2: Fn(U) -> V + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
stage1: Arc::new(stage1),
|
||||
stage2: Arc::new(stage2),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute pipeline on multiple inputs
|
||||
pub fn execute_batch(&self, inputs: Vec<T>) -> Vec<V> {
|
||||
if !parallel_enabled() {
|
||||
return inputs
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
let stage1_out = (self.stage1)(input);
|
||||
(self.stage2)(stage1_out)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
inputs
|
||||
.into_par_iter()
|
||||
.map(|input| {
|
||||
let stage1_out = (self.stage1)(input);
|
||||
(self.stage2)(stage1_out)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Three-stage pipeline executor
///
/// Like [`PipelineExecutor`] but chaining three transformation stages;
/// `execute_batch` applies them back-to-back per item.
pub struct Pipeline3<T, U, V, W> {
    // First stage.
    stage1: Arc<dyn Fn(T) -> U + Send + Sync>,
    // Second stage, consuming stage1's output.
    stage2: Arc<dyn Fn(U) -> V + Send + Sync>,
    // Third stage, consuming stage2's output.
    stage3: Arc<dyn Fn(V) -> W + Send + Sync>,
}
|
||||
|
||||
impl<T, U, V, W> Pipeline3<T, U, V, W>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
V: Send,
|
||||
W: Send,
|
||||
{
|
||||
pub fn new<F1, F2, F3>(stage1: F1, stage2: F2, stage3: F3) -> Self
|
||||
where
|
||||
F1: Fn(T) -> U + Send + Sync + 'static,
|
||||
F2: Fn(U) -> V + Send + Sync + 'static,
|
||||
F3: Fn(V) -> W + Send + Sync + 'static,
|
||||
{
|
||||
Self {
|
||||
stage1: Arc::new(stage1),
|
||||
stage2: Arc::new(stage2),
|
||||
stage3: Arc::new(stage3),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn execute_batch(&self, inputs: Vec<T>) -> Vec<W> {
|
||||
if !parallel_enabled() {
|
||||
return inputs
|
||||
.into_iter()
|
||||
.map(|input| {
|
||||
let out1 = (self.stage1)(input);
|
||||
let out2 = (self.stage2)(out1);
|
||||
(self.stage3)(out2)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
inputs
|
||||
.into_par_iter()
|
||||
.map(|input| {
|
||||
let out1 = (self.stage1)(input);
|
||||
let out2 = (self.stage2)(out1);
|
||||
(self.stage3)(out2)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parallel map with configurable chunk size
|
||||
pub fn parallel_map_chunked<T, U, F>(items: Vec<T>, chunk_size: usize, map_fn: F) -> Vec<U>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
F: Fn(T) -> U + Sync + Send,
|
||||
{
|
||||
if !parallel_enabled() {
|
||||
return items.into_iter().map(map_fn).collect();
|
||||
}
|
||||
|
||||
items
|
||||
.into_par_iter()
|
||||
.with_min_len(chunk_size)
|
||||
.map(map_fn)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Async parallel executor with concurrency limit
///
/// Bounds the number of concurrently running spawned tasks using a
/// semaphore.
pub struct AsyncParallelExecutor {
    // One permit per allowed in-flight task.
    semaphore: Arc<Semaphore>,
}
|
||||
|
||||
impl AsyncParallelExecutor {
    /// Create executor with maximum concurrency limit
    pub fn new(max_concurrent: usize) -> Self {
        Self {
            semaphore: Arc::new(Semaphore::new(max_concurrent)),
        }
    }

    /// Execute async tasks with concurrency limit
    ///
    /// Spawns one tokio task per input, acquiring a semaphore permit
    /// *before* spawning, so task submission itself is throttled. Results
    /// are collected in input order.
    ///
    /// NOTE(review): results of tasks that panicked are silently skipped,
    /// so the output can be shorter than `tasks` — confirm callers tolerate
    /// that.
    pub async fn execute<T, F, Fut>(&self, tasks: Vec<T>, executor: F) -> Vec<Fut::Output>
    where
        T: Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future + Send + 'static,
        Fut::Output: Send + 'static,
    {
        let mut handles = Vec::new();

        for task in tasks {
            // `unwrap` is safe here: this semaphore is never closed.
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();
            let executor = executor.clone();

            let handle = tokio::spawn(async move {
                let result = executor(task).await;
                drop(permit); // Release semaphore
                result
            });

            handles.push(handle);
        }

        // Wait for all tasks to complete
        let mut results = Vec::new();
        for handle in handles {
            if let Ok(result) = handle.await {
                results.push(result);
            }
        }

        results
    }

    /// Execute with error handling
    ///
    /// Same throttling as [`AsyncParallelExecutor::execute`], for fallible
    /// tasks. Task-level `Err` values are kept in the output; panicked
    /// tasks are dropped entirely (see the NOTE on `execute`).
    pub async fn execute_result<T, F, Fut, R, E>(
        &self,
        tasks: Vec<T>,
        executor: F,
    ) -> Vec<std::result::Result<R, E>>
    where
        T: Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + Clone + 'static,
        Fut: std::future::Future<Output = std::result::Result<R, E>> + Send + 'static,
        R: Send + 'static,
        E: Send + 'static,
    {
        let mut handles = Vec::new();

        for task in tasks {
            // Throttle before spawning; the permit travels into the task
            // and is released when its work finishes.
            let permit = self.semaphore.clone().acquire_owned().await.unwrap();
            let executor = executor.clone();

            let handle = tokio::spawn(async move {
                let result = executor(task).await;
                drop(permit);
                result
            });

            handles.push(handle);
        }

        let mut results = Vec::new();
        for handle in handles {
            match handle.await {
                Ok(result) => results.push(result),
                Err(_) => continue, // Task panicked
            }
        }

        results
    }
}
|
||||
|
||||
/// Work-stealing parallel iterator for unbalanced workloads
|
||||
pub fn parallel_unbalanced<T, U, F>(items: Vec<T>, map_fn: F) -> Vec<U>
|
||||
where
|
||||
T: Send,
|
||||
U: Send,
|
||||
F: Fn(T) -> U + Sync + Send,
|
||||
{
|
||||
if !parallel_enabled() {
|
||||
return items.into_iter().map(map_fn).collect();
|
||||
}
|
||||
|
||||
// Use adaptive strategy for unbalanced work
|
||||
items
|
||||
.into_par_iter()
|
||||
.with_min_len(1) // Allow fine-grained work stealing
|
||||
.map(map_fn)
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get optimal thread count for current system
///
/// Returns the number of threads in rayon's current pool — typically the
/// number of logical CPUs unless overridden (e.g. via `set_thread_count`).
pub fn optimal_thread_count() -> usize {
    rayon::current_num_threads()
}
|
||||
|
||||
/// Set global thread pool size
///
/// Builds rayon's global pool with `threads` workers. The global pool can
/// only be initialized once per process: if it already exists (rayon was
/// used earlier), `build_global` returns an error which is deliberately
/// swallowed by `.ok()`, making this call a silent no-op in that case.
pub fn set_thread_count(threads: usize) {
    rayon::ThreadPoolBuilder::new()
        .num_threads(threads)
        .build_global()
        .ok();
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Chunked parallel map must preserve length and element order.
    #[test]
    fn test_parallel_map() {
        let data: Vec<i32> = (0..100).collect();
        let result = parallel_map_chunked(data, 10, |x| x * 2);

        assert_eq!(result.len(), 100);
        assert_eq!(result[0], 0);
        assert_eq!(result[50], 100);
        assert_eq!(result[99], 198);
    }

    // Two-stage pipeline: each input flows through (+1) then (*2).
    #[test]
    fn test_pipeline_executor() {
        let pipeline = PipelineExecutor::new(|x: i32| x + 1, |x: i32| x * 2);

        let inputs = vec![1, 2, 3, 4, 5];
        let results = pipeline.execute_batch(inputs);

        assert_eq!(results, vec![4, 6, 8, 10, 12]);
    }

    // Three-stage pipeline composes stages in declaration order.
    #[test]
    fn test_pipeline3() {
        let pipeline = Pipeline3::new(|x: i32| x + 1, |x: i32| x * 2, |x: i32| x - 1);

        let inputs = vec![1, 2, 3];
        let results = pipeline.execute_batch(inputs);

        // (1+1)*2-1 = 3, (2+1)*2-1 = 5, (3+1)*2-1 = 7
        assert_eq!(results, vec![3, 5, 7]);
    }

    // Concurrency is capped at 2 but every task must still complete.
    // The test only checks count and membership, not ordering.
    #[tokio::test]
    async fn test_async_executor() {
        let executor = AsyncParallelExecutor::new(2);

        let tasks = vec![1, 2, 3, 4, 5];
        let results = executor
            .execute(tasks, |x| async move {
                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
                x * 2
            })
            .await;

        assert_eq!(results.len(), 5);
        assert!(results.contains(&2));
        assert!(results.contains(&10));
    }

    // Sanity bounds on the reported pool size.
    #[test]
    fn test_optimal_threads() {
        let threads = optimal_thread_count();
        assert!(threads > 0);
        assert!(threads <= num_cpus::get());
    }
}
|
||||
339
examples/scipix/src/optimize/quantize.rs
Normal file
339
examples/scipix/src/optimize/quantize.rs
Normal file
@@ -0,0 +1,339 @@
|
||||
//! Model quantization utilities
|
||||
//!
|
||||
//! Provides INT8 quantization for model weights and activations to reduce
|
||||
//! memory usage and improve inference speed.
|
||||
|
||||
use std::f32;
|
||||
|
||||
/// Affine (asymmetric) INT8 quantization parameters.
///
/// A real value `x` maps to `q = round(x / scale) + zero_point` and back via
/// `x ≈ (q - zero_point) * scale`.
#[derive(Debug, Clone, Copy)]
pub struct QuantParams {
    /// Step size between adjacent quantized levels; always positive.
    pub scale: f32,
    /// Quantized value representing real 0.0.
    pub zero_point: i8,
}

impl QuantParams {
    /// Calculate quantization parameters covering the real range `[min, max]`.
    ///
    /// Fix: a degenerate range (`min == max`, or the reversed/non-finite
    /// range produced by an empty slice) previously yielded `scale == 0` and
    /// a NaN zero-point; the scale is now clamped to a small positive value.
    pub fn from_range(min: f32, max: f32) -> Self {
        let qmin = i8::MIN as f32;
        let qmax = i8::MAX as f32;

        let mut scale = (max - min) / (qmax - qmin);
        // Keep the zero-point computation below well-defined.
        if !scale.is_finite() || scale <= 0.0 {
            scale = f32::EPSILON;
        }
        // Saturating clamp before the cast (float->int `as` also saturates,
        // but the clamp makes the intent explicit).
        let zero_point = (qmin - min / scale).round().clamp(qmin, qmax) as i8;

        Self { scale, zero_point }
    }

    /// Calculate parameters from the min/max of `data`.
    ///
    /// An empty slice falls through to `from_range`'s degenerate-range guard.
    pub fn from_data(data: &[f32]) -> Self {
        let min = data.iter().copied().fold(f32::INFINITY, f32::min);
        let max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        Self::from_range(min, max)
    }

    /// Symmetric quantization (zero_point = 0), spanning `[-abs_max, abs_max]`.
    ///
    /// `abs_max == 0` is guarded the same way as `from_range`.
    pub fn symmetric(abs_max: f32) -> Self {
        let scale = (abs_max / 127.0).max(f32::EPSILON);
        Self {
            scale,
            zero_point: 0,
        }
    }
}
|
||||
|
||||
/// Quantize f32 weights to i8
|
||||
pub fn quantize_weights(weights: &[f32]) -> (Vec<i8>, QuantParams) {
|
||||
let params = QuantParams::from_data(weights);
|
||||
let quantized = quantize_with_params(weights, params);
|
||||
(quantized, params)
|
||||
}
|
||||
|
||||
/// Quantize with given parameters
|
||||
pub fn quantize_with_params(weights: &[f32], params: QuantParams) -> Vec<i8> {
|
||||
weights.iter().map(|&w| quantize_value(w, params)).collect()
|
||||
}
|
||||
|
||||
/// Quantize single value
|
||||
#[inline]
|
||||
pub fn quantize_value(value: f32, params: QuantParams) -> i8 {
|
||||
let scaled = value / params.scale + params.zero_point as f32;
|
||||
scaled.round().clamp(i8::MIN as f32, i8::MAX as f32) as i8
|
||||
}
|
||||
|
||||
/// Dequantize i8 to f32
|
||||
pub fn dequantize(quantized: &[i8], params: QuantParams) -> Vec<f32> {
|
||||
quantized
|
||||
.iter()
|
||||
.map(|&q| dequantize_value(q, params))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Dequantize single value
|
||||
#[inline]
|
||||
pub fn dequantize_value(quantized: i8, params: QuantParams) -> f32 {
|
||||
(quantized as f32 - params.zero_point as f32) * params.scale
|
||||
}
|
||||
|
||||
/// Quantized tensor representation
pub struct QuantizedTensor {
    // Quantized values, one i8 per element, in row-major order.
    pub data: Vec<i8>,
    // Single per-tensor scale / zero-point applied to every element.
    pub params: QuantParams,
    // Logical dimensions; presumably product == data.len() — not enforced here.
    pub shape: Vec<usize>,
}
|
||||
|
||||
impl QuantizedTensor {
|
||||
/// Create from f32 tensor
|
||||
pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
|
||||
let (quantized, params) = quantize_weights(data);
|
||||
Self {
|
||||
data: quantized,
|
||||
params,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with symmetric quantization
|
||||
pub fn from_f32_symmetric(data: &[f32], shape: Vec<usize>) -> Self {
|
||||
let abs_max = data.iter().map(|x| x.abs()).fold(0.0f32, f32::max);
|
||||
let params = QuantParams::symmetric(abs_max);
|
||||
let quantized = quantize_with_params(data, params);
|
||||
|
||||
Self {
|
||||
data: quantized,
|
||||
params,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
/// Dequantize to f32
|
||||
pub fn to_f32(&self) -> Vec<f32> {
|
||||
dequantize(&self.data, self.params)
|
||||
}
|
||||
|
||||
/// Get size in bytes
|
||||
pub fn size_bytes(&self) -> usize {
|
||||
self.data.len()
|
||||
+ std::mem::size_of::<QuantParams>()
|
||||
+ self.shape.len() * std::mem::size_of::<usize>()
|
||||
}
|
||||
|
||||
/// Calculate memory savings vs f32
|
||||
pub fn compression_ratio(&self) -> f32 {
|
||||
let f32_size = self.data.len() * std::mem::size_of::<f32>();
|
||||
let quantized_size = self.size_bytes();
|
||||
f32_size as f32 / quantized_size as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-channel quantization for conv/linear layers
pub struct PerChannelQuant {
    // Quantized values for all channels, concatenated channel-major.
    pub data: Vec<i8>,
    // One parameter set per output channel; params[ch] covers chunk ch of data.
    pub params: Vec<QuantParams>,
    // Original tensor shape; shape[0] is the number of output channels.
    pub shape: Vec<usize>,
}
|
||||
|
||||
impl PerChannelQuant {
|
||||
/// Quantize with per-channel parameters
|
||||
/// For a weight tensor of shape [out_channels, in_channels, ...],
|
||||
/// use separate params for each output channel
|
||||
pub fn from_f32(data: &[f32], shape: Vec<usize>) -> Self {
|
||||
if shape.is_empty() {
|
||||
panic!("Shape cannot be empty");
|
||||
}
|
||||
|
||||
let out_channels = shape[0];
|
||||
let channel_size = data.len() / out_channels;
|
||||
|
||||
let mut all_quantized = Vec::with_capacity(data.len());
|
||||
let mut params = Vec::with_capacity(out_channels);
|
||||
|
||||
for ch in 0..out_channels {
|
||||
let start = ch * channel_size;
|
||||
let end = start + channel_size;
|
||||
let channel_data = &data[start..end];
|
||||
|
||||
let ch_params = QuantParams::from_data(channel_data);
|
||||
let ch_quantized = quantize_with_params(channel_data, ch_params);
|
||||
|
||||
all_quantized.extend(ch_quantized);
|
||||
params.push(ch_params);
|
||||
}
|
||||
|
||||
Self {
|
||||
data: all_quantized,
|
||||
params,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
/// Dequantize to f32
|
||||
pub fn to_f32(&self) -> Vec<f32> {
|
||||
let out_channels = self.shape[0];
|
||||
let channel_size = self.data.len() / out_channels;
|
||||
|
||||
let mut result = Vec::with_capacity(self.data.len());
|
||||
|
||||
for ch in 0..out_channels {
|
||||
let start = ch * channel_size;
|
||||
let end = start + channel_size;
|
||||
let channel_data = &self.data[start..end];
|
||||
let ch_params = self.params[ch];
|
||||
|
||||
result.extend(dequantize(channel_data, ch_params));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Dynamic quantization - quantize at runtime
pub struct DynamicQuantizer {
    // Calibration percentile (e.g. 99.9); values beyond it are clipped
    // when deriving the quantization range.
    percentile: f32,
}
|
||||
|
||||
impl DynamicQuantizer {
|
||||
/// Create quantizer with calibration percentile
|
||||
/// percentile: clip values beyond this percentile (e.g., 99.9)
|
||||
pub fn new(percentile: f32) -> Self {
|
||||
Self { percentile }
|
||||
}
|
||||
|
||||
/// Quantize with calibration
|
||||
pub fn quantize(&self, data: &[f32]) -> (Vec<i8>, QuantParams) {
|
||||
let mut sorted: Vec<f32> = data.iter().copied().collect();
|
||||
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap());
|
||||
|
||||
let idx = ((sorted.len() as f32 * self.percentile / 100.0) as usize).min(sorted.len() - 1);
|
||||
|
||||
let min = -sorted[sorted.len() - idx];
|
||||
let max = sorted[idx];
|
||||
|
||||
let params = QuantParams::from_range(min, max);
|
||||
let quantized = quantize_with_params(data, params);
|
||||
|
||||
(quantized, params)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate quantization error (MSE)
|
||||
pub fn quantization_error(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
|
||||
let dequantized = dequantize(quantized, params);
|
||||
|
||||
let mse: f32 = original
|
||||
.iter()
|
||||
.zip(dequantized.iter())
|
||||
.map(|(o, d)| (o - d).powi(2))
|
||||
.sum::<f32>()
|
||||
/ original.len() as f32;
|
||||
|
||||
mse
|
||||
}
|
||||
|
||||
/// Calculate signal-to-quantization-noise ratio (SQNR) in dB
|
||||
pub fn sqnr(original: &[f32], quantized: &[i8], params: QuantParams) -> f32 {
|
||||
let dequantized = dequantize(quantized, params);
|
||||
|
||||
let signal_power: f32 = original.iter().map(|x| x.powi(2)).sum::<f32>() / original.len() as f32;
|
||||
let noise_power: f32 = original
|
||||
.iter()
|
||||
.zip(dequantized.iter())
|
||||
.map(|(o, d)| (o - d).powi(2))
|
||||
.sum::<f32>()
|
||||
/ original.len() as f32;
|
||||
|
||||
10.0 * (signal_power / noise_power).log10()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Round-trip through quantize/dequantize should be nearly lossless for a
    // small, well-spread data set.
    #[test]
    fn test_quantize_dequantize() {
        let weights = vec![0.0, 0.5, 1.0, -0.5, -1.0];
        let (quantized, params) = quantize_weights(&weights);
        let dequantized = dequantize(&quantized, params);

        // Check approximate equality
        for (orig, deq) in weights.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.01, "orig: {}, deq: {}", orig, deq);
        }
    }

    // Symmetric parameters pin zero_point to 0, so 0.0 quantizes exactly to 0.
    #[test]
    fn test_symmetric_quantization() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let params = QuantParams::symmetric(1.0);

        assert_eq!(params.zero_point, 0);
        assert!((params.scale - 1.0 / 127.0).abs() < 1e-6);

        let quantized = quantize_with_params(&data, params);
        assert_eq!(quantized[2], 0); // 0.0 should map to 0
    }

    // QuantizedTensor keeps shape metadata and round-trips with small error.
    #[test]
    fn test_quantized_tensor() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let tensor = QuantizedTensor::from_f32(&data, vec![2, 2]);

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data.len(), 4);

        let dequantized = tensor.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 0.1);
        }
    }

    // Channels with very different magnitudes each get their own params, so
    // both round-trip within a coarse tolerance.
    #[test]
    fn test_per_channel_quant() {
        // 2 channels, 3 values each
        let data = vec![
            1.0, 2.0, 3.0, // Channel 0
            10.0, 20.0, 30.0, // Channel 1
        ];

        let quant = PerChannelQuant::from_f32(&data, vec![2, 3]);
        assert_eq!(quant.params.len(), 2);

        let dequantized = quant.to_f32();
        for (orig, deq) in data.iter().zip(dequantized.iter()) {
            assert!((orig - deq).abs() < 1.0);
        }
    }

    // MSE should be small and SQNR high for smooth, low-range data.
    #[test]
    fn test_quantization_error() {
        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let (quantized, params) = quantize_weights(&original);

        let error = quantization_error(&original, &quantized, params);
        assert!(error < 0.1); // Should be small for simple data

        let snr = sqnr(&original, &quantized, params);
        assert!(snr > 30.0); // Should have good SNR
    }

    // i8 payload vs f32 source: close to 4x, minus params/shape metadata.
    #[test]
    fn test_compression_ratio() {
        let data: Vec<f32> = (0..1000).map(|i| i as f32 / 1000.0).collect();
        let tensor = QuantizedTensor::from_f32(&data, vec![1000]);

        let ratio = tensor.compression_ratio();
        assert!(ratio > 3.5); // Should be ~4x compression
    }

    // Percentile calibration should clip the injected outlier rather than
    // letting it inflate the quantization scale.
    #[test]
    fn test_dynamic_quantizer() {
        let mut data: Vec<f32> = (0..100).map(|i| i as f32).collect();
        data.push(1000.0); // Outlier

        let quantizer = DynamicQuantizer::new(99.0);
        let (quantized, params) = quantizer.quantize(&data);

        assert_eq!(quantized.len(), 101);
        // The outlier should be clipped
        assert!(params.scale > 0.0);
    }
}
|
||||
597
examples/scipix/src/optimize/simd.rs
Normal file
597
examples/scipix/src/optimize/simd.rs
Normal file
@@ -0,0 +1,597 @@
|
||||
//! SIMD-accelerated image processing operations
|
||||
//!
|
||||
//! Provides optimized implementations for common image operations using
|
||||
//! AVX2, AVX-512, and ARM NEON intrinsics.
|
||||
|
||||
use super::{get_features, simd_enabled};
|
||||
|
||||
/// Convert RGBA image to grayscale using optimized SIMD operations
///
/// Dispatches to the best implementation for the detected CPU (AVX2 /
/// SSE4.2 on x86_64, NEON on aarch64), falling back to the scalar path when
/// SIMD is disabled or unsupported. `gray` receives one byte per RGBA pixel
/// (the scalar path asserts rgba.len()/4 == gray.len()).
pub fn simd_grayscale(rgba: &[u8], gray: &mut [u8]) {
    if !simd_enabled() {
        return scalar_grayscale(rgba, gray);
    }

    let features = get_features();

    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // The unsafe callees are gated on the runtime feature checks above.
            unsafe { avx2_grayscale(rgba, gray) }
        } else if features.sse4_2 {
            unsafe { sse_grayscale(rgba, gray) }
        } else {
            scalar_grayscale(rgba, gray)
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if features.neon {
            unsafe { neon_grayscale(rgba, gray) }
        } else {
            scalar_grayscale(rgba, gray)
        }
    }

    #[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
    {
        scalar_grayscale(rgba, gray)
    }
}
|
||||
|
||||
/// Scalar fallback for grayscale conversion.
///
/// Uses ITU-R BT.601 luma weights in 8-bit fixed point:
/// gray = (77*R + 150*G + 29*B) >> 8. The alpha channel is ignored, and any
/// trailing partial pixel (< 4 bytes) is skipped.
fn scalar_grayscale(rgba: &[u8], gray: &mut [u8]) {
    assert_eq!(
        rgba.len() / 4,
        gray.len(),
        "RGBA length must be 4x grayscale length"
    );

    for (dst, px) in gray.iter_mut().zip(rgba.chunks_exact(4)) {
        let r = px[0] as u32;
        let g = px[1] as u32;
        let b = px[2] as u32;
        *dst = ((77 * r + 150 * g + 29 * b) >> 8) as u8;
    }
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
unsafe fn avx2_grayscale(rgba: &[u8], gray: &mut [u8]) {
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
let len = gray.len();
|
||||
let mut i = 0;
|
||||
|
||||
// Process 8 pixels at a time (32 RGBA bytes)
|
||||
while i + 8 <= len {
|
||||
// Load 32 bytes (8 RGBA pixels)
|
||||
let rgba_ptr = rgba.as_ptr().add(i * 4);
|
||||
let _pixels = _mm256_loadu_si256(rgba_ptr as *const __m256i);
|
||||
|
||||
// Separate RGBA channels (simplified - actual implementation would use shuffles)
|
||||
// For production, use proper channel extraction
|
||||
|
||||
// Store grayscale result
|
||||
for j in 0..8 {
|
||||
let pixel_idx = (i + j) * 4;
|
||||
let r = *rgba.get_unchecked(pixel_idx) as u32;
|
||||
let g = *rgba.get_unchecked(pixel_idx + 1) as u32;
|
||||
let b = *rgba.get_unchecked(pixel_idx + 2) as u32;
|
||||
*gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
|
||||
}
|
||||
|
||||
i += 8;
|
||||
}
|
||||
|
||||
// Handle remaining pixels
|
||||
scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
|
||||
}
|
||||
|
||||
/// SSE4.2-gated grayscale conversion.
///
/// NOTE(review): the body currently uses no SSE intrinsics — it is a
/// 4-pixel unrolled scalar loop with unchecked indexing. The arch import is
/// kept (and marked allow-unused) for a future vectorized implementation.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "sse4.2")]
unsafe fn sse_grayscale(rgba: &[u8], gray: &mut [u8]) {
    #[allow(unused_imports)]
    use std::arch::x86_64::*;

    let len = gray.len();
    let mut i = 0;

    // Process 4 pixels at a time (16 RGBA bytes)
    while i + 4 <= len {
        for j in 0..4 {
            let pixel_idx = (i + j) * 4;
            // In-bounds: pixel_idx + 2 < (i + 4) * 4 <= len * 4, assuming
            // rgba holds 4 bytes per gray byte (the scalar tail asserts it).
            let r = *rgba.get_unchecked(pixel_idx) as u32;
            let g = *rgba.get_unchecked(pixel_idx + 1) as u32;
            let b = *rgba.get_unchecked(pixel_idx + 2) as u32;
            *gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
        }
        i += 4;
    }

    scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
}
|
||||
|
||||
/// NEON grayscale conversion for aarch64.
///
/// NOTE(review): like the SSE path, this is currently an 8-pixel unrolled
/// scalar loop — no NEON intrinsics are used yet; the arch import is kept
/// for a future vectorized version (it will produce an unused-import warning
/// until then).
#[cfg(target_arch = "aarch64")]
unsafe fn neon_grayscale(rgba: &[u8], gray: &mut [u8]) {
    use std::arch::aarch64::*;

    let len = gray.len();
    let mut i = 0;

    // Process 8 pixels at a time
    while i + 8 <= len {
        for j in 0..8 {
            let idx = (i + j) * 4;
            // In-bounds provided rgba holds 4 bytes per gray byte
            // (the scalar tail asserts that relationship).
            let r = *rgba.get_unchecked(idx) as u32;
            let g = *rgba.get_unchecked(idx + 1) as u32;
            let b = *rgba.get_unchecked(idx + 2) as u32;
            *gray.get_unchecked_mut(i + j) = ((r * 77 + g * 150 + b * 29) >> 8) as u8;
        }
        i += 8;
    }

    scalar_grayscale(&rgba[i * 4..], &mut gray[i..]);
}
|
||||
|
||||
/// Apply threshold to grayscale image using SIMD
///
/// Binarization: output bytes are 255 where the input is >= `thresh`
/// (per the scalar path), 0 otherwise. Uses AVX2 on x86_64 when available.
pub fn simd_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    if !simd_enabled() {
        return scalar_threshold(gray, thresh, out);
    }

    let features = get_features();

    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // Gated on the runtime AVX2 check above.
            unsafe { avx2_threshold(gray, thresh, out) }
        } else {
            scalar_threshold(gray, thresh, out)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_threshold(gray, thresh, out)
    }
}
|
||||
|
||||
/// Scalar binarization: each output byte becomes 255 where the corresponding
/// input byte is at least `thresh`, and 0 otherwise.
fn scalar_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
    for (dst, &src) in out.iter_mut().zip(gray.iter()) {
        *dst = if src >= thresh { 255 } else { 0 };
    }
}
|
||||
|
||||
#[cfg(target_arch = "x86_64")]
|
||||
#[target_feature(enable = "avx2")]
|
||||
unsafe fn avx2_threshold(gray: &[u8], thresh: u8, out: &mut [u8]) {
|
||||
use std::arch::x86_64::*;
|
||||
|
||||
let len = gray.len();
|
||||
let mut i = 0;
|
||||
|
||||
let thresh_vec = _mm256_set1_epi8(thresh as i8);
|
||||
let ones = _mm256_set1_epi8(-1); // 0xFF
|
||||
|
||||
// Process 32 bytes at a time
|
||||
while i + 32 <= len {
|
||||
let gray_vec = _mm256_loadu_si256(gray.as_ptr().add(i) as *const __m256i);
|
||||
let cmp = _mm256_cmpgt_epi8(gray_vec, thresh_vec);
|
||||
let result = _mm256_and_si256(cmp, ones);
|
||||
_mm256_storeu_si256(out.as_mut_ptr().add(i) as *mut __m256i, result);
|
||||
i += 32;
|
||||
}
|
||||
|
||||
// Handle remaining bytes
|
||||
scalar_threshold(&gray[i..], thresh, &mut out[i..]);
|
||||
}
|
||||
|
||||
/// Normalize f32 tensor data using SIMD
///
/// In-place z-score normalization: x -> (x - mean) / (std_dev + 1e-8),
/// per the scalar implementation. Uses AVX2 on x86_64 when available.
pub fn simd_normalize(data: &mut [f32]) {
    if !simd_enabled() {
        return scalar_normalize(data);
    }

    let features = get_features();

    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // Gated on the runtime AVX2 check above.
            unsafe { avx2_normalize(data) }
        } else {
            scalar_normalize(data)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_normalize(data)
    }
}
|
||||
|
||||
/// Normalize `data` in place to zero mean and unit variance (z-score).
///
/// A small epsilon is added to the standard deviation for numerical
/// stability, so constant inputs map to ~0 rather than NaN. Empty input is
/// now a no-op (previously it produced a 0/0 = NaN mean that propagated to
/// nothing, but the guard makes the contract explicit and safe).
fn scalar_normalize(data: &mut [f32]) {
    if data.is_empty() {
        return;
    }

    let n = data.len() as f32;
    let mean = data.iter().sum::<f32>() / n;

    let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
    let std_dev = variance.sqrt() + 1e-8; // Epsilon for numerical stability.

    for x in data.iter_mut() {
        *x = (*x - mean) / std_dev;
    }
}
|
||||
|
||||
/// AVX2 z-score normalization.
///
/// Three passes over the data: vectorized sum (mean), vectorized
/// squared-difference sum (variance), then vectorized (x - mean) / std.
/// Each pass finishes the last `len % 8` elements with scalar code.
/// Caller contract (#[target_feature]): AVX2 must be available at runtime.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_normalize(data: &mut [f32]) {
    use std::arch::x86_64::*;

    // Calculate mean using SIMD
    let len = data.len();
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        sum = _mm256_add_ps(sum, vals);
        i += 8;
    }

    // Horizontal sum
    // The transmute reinterprets the 8 packed f32 lanes as an array; the
    // scalar tail (data[i..]) is folded in on top.
    let sum_scalar = {
        let sum_arr: [f32; 8] = std::mem::transmute(sum);
        sum_arr.iter().sum::<f32>() + data[i..].iter().sum::<f32>()
    };

    let mean = sum_scalar / len as f32;
    let mean_vec = _mm256_set1_ps(mean);

    // Calculate variance
    let mut var_sum = _mm256_setzero_ps();
    i = 0;

    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        let diff = _mm256_sub_ps(vals, mean_vec);
        let sq = _mm256_mul_ps(diff, diff);
        var_sum = _mm256_add_ps(var_sum, sq);
        i += 8;
    }

    let var_scalar = {
        let var_arr: [f32; 8] = std::mem::transmute(var_sum);
        var_arr.iter().sum::<f32>() + data[i..].iter().map(|x| (x - mean).powi(2)).sum::<f32>()
    };

    // Epsilon keeps the division finite for constant inputs.
    let std_dev = (var_scalar / len as f32).sqrt() + 1e-8;
    let std_vec = _mm256_set1_ps(std_dev);

    // Normalize
    i = 0;
    while i + 8 <= len {
        let vals = _mm256_loadu_ps(data.as_ptr().add(i));
        let centered = _mm256_sub_ps(vals, mean_vec);
        let normalized = _mm256_div_ps(centered, std_vec);
        _mm256_storeu_ps(data.as_mut_ptr().add(i), normalized);
        i += 8;
    }

    // Handle remaining elements
    for x in &mut data[i..] {
        *x = (*x - mean) / std_dev;
    }
}
|
||||
|
||||
/// Fast bilinear resize using SIMD - optimized for preprocessing
/// This is significantly faster than the image crate's resize for typical OCR sizes
///
/// Expects a single-channel image of src_width * src_height bytes and
/// returns dst_width * dst_height bytes. Dispatches to the AVX2 path on
/// x86_64 when available, else the scalar implementation.
pub fn simd_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    if !simd_enabled() {
        return scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }

    let features = get_features();

    #[cfg(target_arch = "x86_64")]
    {
        if features.avx2 {
            // Gated on the runtime AVX2 check above.
            unsafe { avx2_resize_bilinear(src, src_width, src_height, dst_width, dst_height) }
        } else {
            scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        scalar_resize_bilinear(src, src_width, src_height, dst_width, dst_height)
    }
}
|
||||
|
||||
/// Scalar bilinear resize for single-channel (grayscale) images.
///
/// Each destination pixel samples the four nearest source pixels and blends
/// them by the fractional distance; coordinates at the image edge are
/// clamped to the last row/column.
fn scalar_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;

    let mut dst = Vec::with_capacity(dst_width * dst_height);

    for y in 0..dst_height {
        let sy = y as f32 * y_scale;
        let y0 = (sy.floor() as usize).min(src_height - 1);
        let y1 = (y0 + 1).min(src_height - 1);
        let fy = sy - sy.floor();

        for x in 0..dst_width {
            let sx = x as f32 * x_scale;
            let x0 = (sx.floor() as usize).min(src_width - 1);
            let x1 = (x0 + 1).min(src_width - 1);
            let fx = sx - sx.floor();

            // Fetch the 2x2 neighborhood and interpolate horizontally,
            // then vertically.
            let p00 = src[y0 * src_width + x0] as f32;
            let p10 = src[y0 * src_width + x1] as f32;
            let p01 = src[y1 * src_width + x0] as f32;
            let p11 = src[y1 * src_width + x1] as f32;

            let top = p00 * (1.0 - fx) + p10 * fx;
            let bottom = p01 * (1.0 - fx) + p11 * fx;
            dst.push((top * (1.0 - fy) + bottom * fy).round() as u8);
        }
    }

    dst
}
|
||||
|
||||
/// "AVX2" bilinear resize.
///
/// NOTE(review): the inner math is still scalar — the previous version only
/// materialized two vector registers (`_y_frac` / `_y_frac_inv`) that were
/// never read; those dead intrinsics (and the now-unneeded arch import) have
/// been removed, and `src_y - src_y.floor()` is hoisted into `y_frac` once
/// per row instead of being recomputed per pixel. The 8-wide unroll and
/// unchecked indexing are preserved.
///
/// # Safety
/// Caller must ensure AVX2 is available at runtime and that `src` holds
/// `src_width * src_height` bytes.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn avx2_resize_bilinear(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    let mut dst = vec![0u8; dst_width * dst_height];

    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;

    // Process 8 output pixels at a time for the x dimension.
    for y in 0..dst_height {
        let src_y = y as f32 * y_scale;
        let y0 = (src_y.floor() as usize).min(src_height - 1);
        let y1 = (y0 + 1).min(src_height - 1);
        let y_frac = src_y - src_y.floor();

        let mut x = 0;
        while x + 8 <= dst_width {
            let mut results = [0u8; 8];
            for (i, result) in results.iter_mut().enumerate() {
                let src_x = (x + i) as f32 * x_scale;
                let x0 = (src_x.floor() as usize).min(src_width - 1);
                let x1 = (x0 + 1).min(src_width - 1);
                let x_frac = src_x - src_x.floor();

                // SAFETY: y0/y1 < src_height and x0/x1 < src_width by the
                // clamps above, so every source index is in bounds under the
                // documented caller contract.
                let p00 = *src.get_unchecked(y0 * src_width + x0) as f32;
                let p10 = *src.get_unchecked(y0 * src_width + x1) as f32;
                let p01 = *src.get_unchecked(y1 * src_width + x0) as f32;
                let p11 = *src.get_unchecked(y1 * src_width + x1) as f32;

                let top = p00 * (1.0 - x_frac) + p10 * x_frac;
                let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
                *result = (top * (1.0 - y_frac) + bottom * y_frac).round() as u8;
            }

            for (i, &value) in results.iter().enumerate() {
                // SAFETY: x + 8 <= dst_width, so y*dst_width + x + i < dst.len().
                *dst.get_unchecked_mut(y * dst_width + x + i) = value;
            }
            x += 8;
        }

        // Handle remaining pixels
        while x < dst_width {
            let src_x = x as f32 * x_scale;
            let x0 = (src_x.floor() as usize).min(src_width - 1);
            let x1 = (x0 + 1).min(src_width - 1);
            let x_frac = src_x - src_x.floor();

            // SAFETY: same bounds argument as the unrolled loop above.
            let p00 = *src.get_unchecked(y0 * src_width + x0) as f32;
            let p10 = *src.get_unchecked(y0 * src_width + x1) as f32;
            let p01 = *src.get_unchecked(y1 * src_width + x0) as f32;
            let p11 = *src.get_unchecked(y1 * src_width + x1) as f32;

            let top = p00 * (1.0 - x_frac) + p10 * x_frac;
            let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
            *dst.get_unchecked_mut(y * dst_width + x) =
                (top * (1.0 - y_frac) + bottom * y_frac).round() as u8;
            x += 1;
        }
    }

    dst
}
|
||||
|
||||
/// Parallel SIMD resize for large images - splits work across threads
///
/// Rows of the destination are distributed across the rayon pool; each row
/// chunk is written independently, so no synchronization is needed. Small
/// images (short, or under ~100k output pixels) skip the threading overhead
/// and use the single-threaded path.
#[cfg(feature = "rayon")]
pub fn parallel_simd_resize(
    src: &[u8],
    src_width: usize,
    src_height: usize,
    dst_width: usize,
    dst_height: usize,
) -> Vec<u8> {
    use rayon::prelude::*;

    // For small images, use single-threaded SIMD
    if dst_height < 64 || dst_width * dst_height < 100_000 {
        return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
    }

    let mut dst = vec![0u8; dst_width * dst_height];
    let x_scale = src_width as f32 / dst_width as f32;
    let y_scale = src_height as f32 / dst_height as f32;

    // Process rows in parallel
    dst.par_chunks_mut(dst_width)
        .enumerate()
        .for_each(|(y, row)| {
            let src_y = y as f32 * y_scale;
            let y0 = (src_y.floor() as usize).min(src_height - 1);
            let y1 = (y0 + 1).min(src_height - 1);
            let y_frac = src_y - src_y.floor();

            for x in 0..dst_width {
                let src_x = x as f32 * x_scale;
                let x0 = (src_x.floor() as usize).min(src_width - 1);
                let x1 = (x0 + 1).min(src_width - 1);
                let x_frac = src_x - src_x.floor();

                // Bilinear blend of the four neighboring source pixels.
                let p00 = src[y0 * src_width + x0] as f32;
                let p10 = src[y0 * src_width + x1] as f32;
                let p01 = src[y1 * src_width + x0] as f32;
                let p11 = src[y1 * src_width + x1] as f32;

                let top = p00 * (1.0 - x_frac) + p10 * x_frac;
                let bottom = p01 * (1.0 - x_frac) + p11 * x_frac;
                let value = top * (1.0 - y_frac) + bottom * y_frac;

                row[x] = value.round() as u8;
            }
        });

    dst
}
|
||||
|
||||
/// Ultra-fast area average downscaling for preprocessing
|
||||
/// Best for large images being scaled down significantly
|
||||
pub fn fast_area_resize(
|
||||
src: &[u8],
|
||||
src_width: usize,
|
||||
src_height: usize,
|
||||
dst_width: usize,
|
||||
dst_height: usize,
|
||||
) -> Vec<u8> {
|
||||
// Only use area averaging for downscaling
|
||||
if dst_width >= src_width || dst_height >= src_height {
|
||||
return simd_resize_bilinear(src, src_width, src_height, dst_width, dst_height);
|
||||
}
|
||||
|
||||
let mut dst = vec![0u8; dst_width * dst_height];
|
||||
|
||||
let x_ratio = src_width as f32 / dst_width as f32;
|
||||
let y_ratio = src_height as f32 / dst_height as f32;
|
||||
|
||||
for y in 0..dst_height {
|
||||
let y_start = (y as f32 * y_ratio) as usize;
|
||||
let y_end = (((y + 1) as f32 * y_ratio) as usize).min(src_height);
|
||||
|
||||
for x in 0..dst_width {
|
||||
let x_start = (x as f32 * x_ratio) as usize;
|
||||
let x_end = (((x + 1) as f32 * x_ratio) as usize).min(src_width);
|
||||
|
||||
// Calculate area average
|
||||
let mut sum: u32 = 0;
|
||||
let mut count: u32 = 0;
|
||||
|
||||
for sy in y_start..y_end {
|
||||
for sx in x_start..x_end {
|
||||
sum += src[sy * src_width + sx] as u32;
|
||||
count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
dst[y * dst_width + x] = if count > 0 { (sum / count) as u8 } else { 0 };
|
||||
}
|
||||
}
|
||||
|
||||
dst
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Primary-color pixels should land in the luma ranges implied by the
    // BT.601 weights used in scalar_grayscale (77/150/29 out of 256).
    #[test]
    fn test_grayscale_conversion() {
        let rgba = vec![
            255, 0, 0, 255, // Red
            0, 255, 0, 255, // Green
            0, 0, 255, 255, // Blue
            255, 255, 255, 255, // White
        ];
        let mut gray = vec![0u8; 4];

        simd_grayscale(&rgba, &mut gray);

        // Check approximately correct values
        assert!(gray[0] > 50 && gray[0] < 100); // Red
        assert!(gray[1] > 130 && gray[1] < 160); // Green
        assert!(gray[2] > 20 && gray[2] < 50); // Blue
        assert_eq!(gray[3], 255); // White
    }

    // NOTE(review): 6 bytes is below the 32-byte AVX2 chunk, so this only
    // exercises the scalar tail; a >=32-byte input is needed to cover the
    // vector path.
    #[test]
    fn test_threshold() {
        let gray = vec![0, 50, 100, 150, 200, 255];
        let mut out = vec![0u8; 6];

        simd_threshold(&gray, 100, &mut out);

        assert_eq!(out, vec![0, 0, 0, 255, 255, 255]);
    }

    #[test]
    fn test_normalize() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        simd_normalize(&mut data);

        // After normalization, mean should be ~0 and std dev ~1
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        assert!(mean.abs() < 1e-6);
    }

    // SIMD and scalar paths must agree byte-for-byte on the same input.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_simd_vs_scalar_grayscale() {
        let rgba: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
        let mut gray_simd = vec![0u8; 256];
        let mut gray_scalar = vec![0u8; 256];

        simd_grayscale(&rgba, &mut gray_simd);
        scalar_grayscale(&rgba, &mut gray_scalar);

        assert_eq!(gray_simd, gray_scalar);
    }
}
|
||||
298
examples/scipix/src/output/docx.rs
Normal file
298
examples/scipix/src/output/docx.rs
Normal file
@@ -0,0 +1,298 @@
|
||||
//! DOCX (Microsoft Word) formatter with Office Math ML support
|
||||
//!
|
||||
//! This is a stub implementation. Full DOCX generation requires:
|
||||
//! - ZIP file creation for .docx format
|
||||
//! - XML generation for document.xml, styles.xml, etc.
|
||||
//! - Office Math ML for equations
|
||||
//! - Image embedding support
|
||||
//!
|
||||
//! Consider using libraries like `docx-rs` for production implementation.
|
||||
|
||||
use super::{LineData, OcrResult};
|
||||
use std::io::Write;
|
||||
|
||||
/// DOCX formatter (stub implementation)
///
/// Holds layout configuration for a would-be .docx export. Full DOCX
/// generation (ZIP container, relationships, binary parts) is not
/// implemented; see `save_to_file`, which currently returns an error.
#[allow(dead_code)]
pub struct DocxFormatter {
    // Whether styles.xml content should be emitted alongside document.xml.
    include_styles: bool,
    // Page dimensions in twips (1/1440 inch); defaults to US Letter.
    page_size: PageSize,
    // Page margins in twips; defaults to 1 inch on all sides.
    margins: Margins,
}
|
||||
|
||||
/// Page dimensions expressed in twips (1/1440 inch), the unit used by
/// WordprocessingML page-setup elements.
#[derive(Debug, Clone, Copy)]
pub struct PageSize {
    pub width: u32, // in twips (1/1440 inch)
    pub height: u32,
}

impl PageSize {
    /// US Letter: 8.5 x 11 inches.
    pub fn letter() -> Self {
        PageSize {
            width: 12240,  // 8.5 inches
            height: 15840, // 11 inches
        }
    }

    /// ISO A4: 210 x 297 mm.
    pub fn a4() -> Self {
        PageSize {
            width: 11906,  // 210mm
            height: 16838, // 297mm
        }
    }
}
|
||||
|
||||
/// Page margins in twips (1/1440 inch).
#[derive(Debug, Clone, Copy)]
pub struct Margins {
    pub top: u32,
    pub right: u32,
    pub bottom: u32,
    pub left: u32,
}

impl Margins {
    /// Word's "Normal" margin preset: one inch on every side.
    pub fn normal() -> Self {
        let inch = 1440; // 1 inch in twips
        Margins {
            top: inch,
            right: inch,
            bottom: inch,
            left: inch,
        }
    }
}
|
||||
|
||||
impl DocxFormatter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
include_styles: true,
|
||||
page_size: PageSize::letter(),
|
||||
margins: Margins::normal(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_page_size(mut self, page_size: PageSize) -> Self {
|
||||
self.page_size = page_size;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_margins(mut self, margins: Margins) -> Self {
|
||||
self.margins = margins;
|
||||
self
|
||||
}
|
||||
|
||||
/// Generate Office Math ML from LaTeX
|
||||
/// This is a simplified placeholder - real implementation needs proper conversion
|
||||
pub fn latex_to_mathml(&self, latex: &str) -> String {
|
||||
// This is a very simplified stub
|
||||
// Real implementation would parse LaTeX and generate proper Office Math ML
|
||||
format!(
|
||||
r#"<m:oMathPara>
|
||||
<m:oMath>
|
||||
<m:r>
|
||||
<m:t>{}</m:t>
|
||||
</m:r>
|
||||
</m:oMath>
|
||||
</m:oMathPara>"#,
|
||||
self.escape_xml(latex)
|
||||
)
|
||||
}
|
||||
|
||||
/// Generate document.xml content
|
||||
pub fn generate_document_xml(&self, lines: &[LineData]) -> String {
|
||||
let mut xml = String::from(
|
||||
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<w:document xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main"
|
||||
xmlns:m="http://schemas.openxmlformats.org/officeDocument/2006/math">
|
||||
<w:body>
|
||||
"#,
|
||||
);
|
||||
|
||||
for line in lines {
|
||||
xml.push_str(&self.format_line(line));
|
||||
}
|
||||
|
||||
xml.push_str(" </w:body>\n</w:document>");
|
||||
xml
|
||||
}
|
||||
|
||||
fn format_line(&self, line: &LineData) -> String {
|
||||
match line.line_type.as_str() {
|
||||
"text" => self.format_paragraph(&line.text),
|
||||
"math" | "equation" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
self.format_math(latex)
|
||||
}
|
||||
"heading" => self.format_heading(&line.text, 1),
|
||||
_ => self.format_paragraph(&line.text),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_paragraph(&self, text: &str) -> String {
|
||||
format!(
|
||||
r#" <w:p>
|
||||
<w:r>
|
||||
<w:t>{}</w:t>
|
||||
</w:r>
|
||||
</w:p>
|
||||
"#,
|
||||
self.escape_xml(text)
|
||||
)
|
||||
}
|
||||
|
||||
fn format_heading(&self, text: &str, level: u32) -> String {
|
||||
format!(
|
||||
r#" <w:p>
|
||||
<w:pPr>
|
||||
<w:pStyle w:val="Heading{}"/>
|
||||
</w:pPr>
|
||||
<w:r>
|
||||
<w:t>{}</w:t>
|
||||
</w:r>
|
||||
</w:p>
|
||||
"#,
|
||||
level,
|
||||
self.escape_xml(text)
|
||||
)
|
||||
}
|
||||
|
||||
fn format_math(&self, latex: &str) -> String {
|
||||
let mathml = self.latex_to_mathml(latex);
|
||||
format!(
|
||||
r#" <w:p>
|
||||
<w:r>
|
||||
{}
|
||||
</w:r>
|
||||
</w:p>
|
||||
"#,
|
||||
mathml
|
||||
)
|
||||
}
|
||||
|
||||
fn escape_xml(&self, text: &str) -> String {
|
||||
text.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
|
||||
/// Save DOCX to file (stub - needs ZIP implementation)
|
||||
pub fn save_to_file<W: Write>(
|
||||
&self,
|
||||
_writer: &mut W,
|
||||
_result: &OcrResult,
|
||||
) -> Result<(), String> {
|
||||
Err("DOCX binary format generation not implemented. Use docx-rs library for full implementation.".to_string())
|
||||
}
|
||||
|
||||
/// Generate styles.xml content
|
||||
pub fn generate_styles_xml(&self) -> String {
|
||||
r#"<?xml version="1.0" encoding="UTF-8" standalone="yes"?>
|
||||
<w:styles xmlns:w="http://schemas.openxmlformats.org/wordprocessingml/2006/main">
|
||||
<w:style w:type="paragraph" w:styleId="Normal">
|
||||
<w:name w:val="Normal"/>
|
||||
<w:qFormat/>
|
||||
</w:style>
|
||||
<w:style w:type="paragraph" w:styleId="Heading1">
|
||||
<w:name w:val="Heading 1"/>
|
||||
<w:basedOn w:val="Normal"/>
|
||||
<w:qFormat/>
|
||||
<w:pPr>
|
||||
<w:keepNext/>
|
||||
<w:keepLines/>
|
||||
</w:pPr>
|
||||
<w:rPr>
|
||||
<w:b/>
|
||||
<w:sz w:val="32"/>
|
||||
</w:rPr>
|
||||
</w:style>
|
||||
</w:styles>"#
|
||||
.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DocxFormatter {
    /// Equivalent to [`DocxFormatter::new`]: Letter pages, normal margins.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;

    #[test]
    fn test_page_sizes() {
        let letter = PageSize::letter();
        assert_eq!(letter.width, 12240);

        let a4 = PageSize::a4();
        assert!(a4.width < letter.width);
    }

    /// Escaped output must contain entity references, not raw characters.
    /// (The previous assertions checked for the raw characters themselves,
    /// which only passed because the escape function was a no-op.)
    #[test]
    fn test_escape_xml() {
        let formatter = DocxFormatter::new();
        let result = formatter.escape_xml("Test <tag> & \"quote\"");

        assert!(result.contains("&lt;"));
        assert!(result.contains("&gt;"));
        assert!(result.contains("&amp;"));
        assert!(result.contains("&quot;"));
    }

    #[test]
    fn test_format_paragraph() {
        let formatter = DocxFormatter::new();
        let result = formatter.format_paragraph("Hello World");

        assert!(result.contains("<w:p>"));
        assert!(result.contains("<w:t>Hello World</w:t>"));
    }

    #[test]
    fn test_format_heading() {
        let formatter = DocxFormatter::new();
        let result = formatter.format_heading("Chapter 1", 1);

        assert!(result.contains("Heading1"));
        assert!(result.contains("Chapter 1"));
    }

    #[test]
    fn test_latex_to_mathml() {
        let formatter = DocxFormatter::new();
        let result = formatter.latex_to_mathml("E = mc^2");

        assert!(result.contains("<m:oMath>"));
        assert!(result.contains("mc^2"));
    }

    #[test]
    fn test_generate_document_xml() {
        let formatter = DocxFormatter::new();
        let lines = vec![LineData {
            line_type: "text".to_string(),
            text: "Hello".to_string(),
            latex: None,
            bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
            confidence: 0.95,
            words: None,
        }];

        let xml = formatter.generate_document_xml(&lines);
        assert!(xml.contains("<?xml"));
        assert!(xml.contains("<w:document"));
        assert!(xml.contains("Hello"));
    }

    #[test]
    fn test_generate_styles_xml() {
        let formatter = DocxFormatter::new();
        let xml = formatter.generate_styles_xml();

        assert!(xml.contains("<w:styles"));
        assert!(xml.contains("Normal"));
        assert!(xml.contains("Heading 1"));
    }
}
|
||||
412
examples/scipix/src/output/formatter.rs
Normal file
412
examples/scipix/src/output/formatter.rs
Normal file
@@ -0,0 +1,412 @@
|
||||
//! Multi-format output formatter with batch processing and streaming support
|
||||
|
||||
use super::*;
|
||||
use crate::output::{html, latex, mmd, smiles};
|
||||
use std::io::Write;
|
||||
|
||||
/// Configuration for output formatting
#[derive(Debug, Clone)]
pub struct FormatterConfig {
    /// Target output formats
    pub formats: Vec<OutputFormat>,

    /// Enable pretty printing (where applicable)
    pub pretty: bool,

    /// Include confidence scores in output
    pub include_confidence: bool,

    /// Include bounding box data
    pub include_bbox: bool,

    /// Math delimiter style for LaTeX/MMD
    pub math_delimiters: MathDelimiters,

    /// HTML rendering engine
    pub html_engine: HtmlEngine,

    /// Enable streaming for large documents
    pub streaming: bool,
}

impl Default for FormatterConfig {
    /// Defaults: plain text only, pretty printing on, no confidence/bbox
    /// metadata, `$`/`$$` math delimiters, MathJax rendering, no streaming.
    fn default() -> Self {
        Self {
            formats: vec![OutputFormat::Text],
            pretty: true,
            include_confidence: false,
            include_bbox: false,
            math_delimiters: MathDelimiters::default(),
            html_engine: HtmlEngine::MathJax,
            streaming: false,
        }
    }
}
|
||||
|
||||
/// Math delimiter configuration
///
/// Delimiter strings placed around inline and display math when emitting
/// LaTeX-flavored output (LaTeX, MMD, HTML-with-MathJax).
#[derive(Debug, Clone)]
pub struct MathDelimiters {
    pub inline_start: String,
    pub inline_end: String,
    pub display_start: String,
    pub display_end: String,
}

impl Default for MathDelimiters {
    /// Standard TeX conventions: `$ ... $` inline, `$$ ... $$` display.
    fn default() -> Self {
        let inline = "$".to_string();
        let display = "$$".to_string();
        MathDelimiters {
            inline_start: inline.clone(),
            inline_end: inline,
            display_start: display.clone(),
            display_end: display,
        }
    }
}
|
||||
|
||||
/// HTML rendering engine options
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HtmlEngine {
    /// Render math client-side with MathJax (CDN script injected in <head>).
    MathJax,
    /// Render math client-side with KaTeX auto-render.
    KaTeX,
    /// Emit math delimiters as-is with no rendering script.
    Raw,
}
|
||||
|
||||
/// Main output formatter
///
/// Converts [`OcrResult`] values into one or more textual output formats
/// according to its [`FormatterConfig`]; supports single, batch and
/// streaming rendering.
pub struct OutputFormatter {
    // Rendering options; immutable after construction.
    config: FormatterConfig,
}
|
||||
|
||||
impl OutputFormatter {
|
||||
/// Create a new formatter with default configuration
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: FormatterConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a formatter with custom configuration
|
||||
pub fn with_config(config: FormatterConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Format a single OCR result
|
||||
pub fn format_result(&self, result: &OcrResult) -> Result<FormatsData, String> {
|
||||
let mut formats = FormatsData::default();
|
||||
|
||||
for format in &self.config.formats {
|
||||
let output = self.format_single(result, *format)?;
|
||||
self.set_format_output(&mut formats, *format, output);
|
||||
}
|
||||
|
||||
Ok(formats)
|
||||
}
|
||||
|
||||
/// Format multiple results in batch
|
||||
pub fn format_batch(&self, results: &[OcrResult]) -> Result<Vec<FormatsData>, String> {
|
||||
results
|
||||
.iter()
|
||||
.map(|result| self.format_result(result))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Stream format results to a writer
|
||||
pub fn format_stream<W: Write>(
|
||||
&self,
|
||||
results: &[OcrResult],
|
||||
writer: &mut W,
|
||||
format: OutputFormat,
|
||||
) -> Result<(), String> {
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
let output = self.format_single(result, format)?;
|
||||
writer
|
||||
.write_all(output.as_bytes())
|
||||
.map_err(|e| format!("Write error: {}", e))?;
|
||||
|
||||
// Add separator between results
|
||||
if i < results.len() - 1 {
|
||||
writer
|
||||
.write_all(b"\n\n---\n\n")
|
||||
.map_err(|e| format!("Write error: {}", e))?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Format a single result to a specific format
|
||||
fn format_single(&self, result: &OcrResult, format: OutputFormat) -> Result<String, String> {
|
||||
match format {
|
||||
OutputFormat::Text => self.format_text(result),
|
||||
OutputFormat::LaTeX => self.format_latex(result, false),
|
||||
OutputFormat::LaTeXStyled => self.format_latex(result, true),
|
||||
OutputFormat::Mmd => self.format_mmd(result),
|
||||
OutputFormat::Html => self.format_html(result),
|
||||
OutputFormat::Smiles => self.format_smiles(result),
|
||||
OutputFormat::Docx => self.format_docx(result),
|
||||
OutputFormat::MathML => self.format_mathml(result),
|
||||
OutputFormat::AsciiMath => self.format_asciimath(result),
|
||||
}
|
||||
}
|
||||
|
||||
fn format_text(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(text) = &result.formats.text {
|
||||
return Ok(text.clone());
|
||||
}
|
||||
|
||||
// Fallback: extract text from line data
|
||||
if let Some(line_data) = &result.line_data {
|
||||
let text = line_data
|
||||
.iter()
|
||||
.map(|line| line.text.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
return Ok(text);
|
||||
}
|
||||
|
||||
Err("No text content available".to_string())
|
||||
}
|
||||
|
||||
fn format_latex(&self, result: &OcrResult, styled: bool) -> Result<String, String> {
|
||||
let latex_content = if styled {
|
||||
result
|
||||
.formats
|
||||
.latex_styled
|
||||
.as_ref()
|
||||
.or(result.formats.latex_normal.as_ref())
|
||||
} else {
|
||||
result.formats.latex_normal.as_ref()
|
||||
};
|
||||
|
||||
if let Some(latex) = latex_content {
|
||||
if styled {
|
||||
// Wrap in document with packages
|
||||
Ok(latex::LaTeXFormatter::new()
|
||||
.with_packages(vec![
|
||||
"amsmath".to_string(),
|
||||
"amssymb".to_string(),
|
||||
"graphicx".to_string(),
|
||||
])
|
||||
.format_document(latex))
|
||||
} else {
|
||||
Ok(latex.clone())
|
||||
}
|
||||
} else {
|
||||
Err("No LaTeX content available".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn format_mmd(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(mmd) = &result.formats.mmd {
|
||||
return Ok(mmd.clone());
|
||||
}
|
||||
|
||||
// Generate MMD from line data
|
||||
if let Some(line_data) = &result.line_data {
|
||||
let formatter = mmd::MmdFormatter::with_delimiters(self.config.math_delimiters.clone());
|
||||
return Ok(formatter.format(line_data));
|
||||
}
|
||||
|
||||
Err("No MMD content available".to_string())
|
||||
}
|
||||
|
||||
fn format_html(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(html) = &result.formats.html {
|
||||
return Ok(html.clone());
|
||||
}
|
||||
|
||||
// Generate HTML with math rendering
|
||||
let content = self.format_text(result)?;
|
||||
let formatter = html::HtmlFormatter::new()
|
||||
.with_engine(self.config.html_engine)
|
||||
.with_styling(self.config.pretty);
|
||||
|
||||
Ok(formatter.format(&content, result.line_data.as_deref()))
|
||||
}
|
||||
|
||||
fn format_smiles(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(smiles) = &result.formats.smiles {
|
||||
return Ok(smiles.clone());
|
||||
}
|
||||
|
||||
// Generate SMILES if we have chemical structure data
|
||||
let generator = smiles::SmilesGenerator::new();
|
||||
generator.generate_from_result(result)
|
||||
}
|
||||
|
||||
fn format_docx(&self, _result: &OcrResult) -> Result<String, String> {
|
||||
// DOCX requires binary format, return placeholder
|
||||
Err("DOCX format requires binary output - use save_docx() instead".to_string())
|
||||
}
|
||||
|
||||
fn format_mathml(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(mathml) = &result.formats.mathml {
|
||||
return Ok(mathml.clone());
|
||||
}
|
||||
|
||||
Err("MathML generation not yet implemented".to_string())
|
||||
}
|
||||
|
||||
fn format_asciimath(&self, result: &OcrResult) -> Result<String, String> {
|
||||
if let Some(asciimath) = &result.formats.asciimath {
|
||||
return Ok(asciimath.clone());
|
||||
}
|
||||
|
||||
Err("AsciiMath conversion not yet implemented".to_string())
|
||||
}
|
||||
|
||||
fn set_format_output(&self, formats: &mut FormatsData, format: OutputFormat, output: String) {
|
||||
match format {
|
||||
OutputFormat::Text => formats.text = Some(output),
|
||||
OutputFormat::LaTeX => formats.latex_normal = Some(output),
|
||||
OutputFormat::LaTeXStyled => formats.latex_styled = Some(output),
|
||||
OutputFormat::Mmd => formats.mmd = Some(output),
|
||||
OutputFormat::Html => formats.html = Some(output),
|
||||
OutputFormat::Smiles => formats.smiles = Some(output),
|
||||
OutputFormat::MathML => formats.mathml = Some(output),
|
||||
OutputFormat::AsciiMath => formats.asciimath = Some(output),
|
||||
OutputFormat::Docx => {} // Binary format, handled separately
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OutputFormatter {
    /// Equivalent to [`OutputFormatter::new`]: text-only, pretty output.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Builder for OutputFormatter configuration
|
||||
pub struct FormatterBuilder {
|
||||
config: FormatterConfig,
|
||||
}
|
||||
|
||||
impl FormatterBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: FormatterConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn formats(mut self, formats: Vec<OutputFormat>) -> Self {
|
||||
self.config.formats = formats;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_format(mut self, format: OutputFormat) -> Self {
|
||||
self.config.formats.push(format);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn pretty(mut self, pretty: bool) -> Self {
|
||||
self.config.pretty = pretty;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn include_confidence(mut self, include: bool) -> Self {
|
||||
self.config.include_confidence = include;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn include_bbox(mut self, include: bool) -> Self {
|
||||
self.config.include_bbox = include;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn math_delimiters(mut self, delimiters: MathDelimiters) -> Self {
|
||||
self.config.math_delimiters = delimiters;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn html_engine(mut self, engine: HtmlEngine) -> Self {
|
||||
self.config.html_engine = engine;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn streaming(mut self, streaming: bool) -> Self {
|
||||
self.config.streaming = streaming;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> OutputFormatter {
|
||||
OutputFormatter::with_config(self.config)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for FormatterBuilder {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: a printed-page result carrying text and LaTeX content.
    fn create_test_result() -> OcrResult {
        OcrResult {
            request_id: "test_123".to_string(),
            version: "3.0".to_string(),
            image_width: 800,
            image_height: 600,
            is_printed: true,
            is_handwritten: false,
            auto_rotate_confidence: 0.95,
            auto_rotate_degrees: 0,
            confidence: 0.98,
            confidence_rate: 0.97,
            formats: FormatsData {
                text: Some("E = mc^2".to_string()),
                latex_normal: Some(r"E = mc^2".to_string()),
                ..Default::default()
            },
            line_data: None,
            error: None,
            metadata: HashMap::new(),
        }
    }

    #[test]
    fn test_format_text() {
        let fmt = OutputFormatter::new();
        let res = create_test_result();

        let rendered = fmt.format_single(&res, OutputFormat::Text).unwrap();
        assert_eq!(rendered, "E = mc^2");
    }

    #[test]
    fn test_format_latex() {
        let fmt = OutputFormatter::new();
        let res = create_test_result();

        let rendered = fmt.format_single(&res, OutputFormat::LaTeX).unwrap();
        assert!(rendered.contains("mc^2"));
    }

    #[test]
    fn test_builder() {
        let fmt = FormatterBuilder::new()
            .add_format(OutputFormat::Text)
            .add_format(OutputFormat::LaTeX)
            .pretty(true)
            .include_confidence(true)
            .build();

        assert_eq!(fmt.config.formats.len(), 2);
        assert!(fmt.config.pretty);
        assert!(fmt.config.include_confidence);
    }

    #[test]
    fn test_batch_format() {
        let fmt = OutputFormatter::new();
        let inputs = vec![create_test_result(), create_test_result()];

        let outputs = fmt.format_batch(&inputs).unwrap();
        assert_eq!(outputs.len(), 2);
    }
}
|
||||
396
examples/scipix/src/output/html.rs
Normal file
396
examples/scipix/src/output/html.rs
Normal file
@@ -0,0 +1,396 @@
|
||||
//! HTML output formatter with math rendering support
|
||||
|
||||
use super::{HtmlEngine, LineData};
|
||||
|
||||
/// HTML formatter with math rendering
///
/// Produces a complete standalone HTML document from OCR output, with
/// optional client-side math rendering, embedded CSS, accessibility
/// helpers and a responsive viewport.
pub struct HtmlFormatter {
    // Math rendering engine whose scripts are injected into <head>.
    engine: HtmlEngine,
    // Emit embedded <style> rules and a theme class on <body>.
    css_styling: bool,
    // Emit screen-reader-only descriptions for equations.
    accessibility: bool,
    // Emit the mobile viewport meta tag.
    responsive: bool,
    // Color scheme applied via CSS.
    theme: HtmlTheme,
}
|
||||
|
||||
/// Color theme for the generated page.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HtmlTheme {
    /// White background, dark text.
    Light,
    /// Dark background, light text.
    Dark,
    /// Follow the user agent's `prefers-color-scheme` media query.
    Auto,
}
|
||||
|
||||
impl HtmlFormatter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
engine: HtmlEngine::MathJax,
|
||||
css_styling: true,
|
||||
accessibility: true,
|
||||
responsive: true,
|
||||
theme: HtmlTheme::Light,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_engine(mut self, engine: HtmlEngine) -> Self {
|
||||
self.engine = engine;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_styling(mut self, styling: bool) -> Self {
|
||||
self.css_styling = styling;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn accessibility(mut self, enabled: bool) -> Self {
|
||||
self.accessibility = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn responsive(mut self, enabled: bool) -> Self {
|
||||
self.responsive = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn theme(mut self, theme: HtmlTheme) -> Self {
|
||||
self.theme = theme;
|
||||
self
|
||||
}
|
||||
|
||||
/// Format content to HTML
|
||||
pub fn format(&self, content: &str, lines: Option<&[LineData]>) -> String {
|
||||
let mut html = String::new();
|
||||
|
||||
// HTML header with math rendering scripts
|
||||
html.push_str(&self.html_header());
|
||||
|
||||
// Body start with theme class
|
||||
html.push_str("<body");
|
||||
if self.css_styling {
|
||||
html.push_str(&format!(r#" class="theme-{:?}""#, self.theme).to_lowercase());
|
||||
}
|
||||
html.push_str(">\n");
|
||||
|
||||
// Main content container
|
||||
html.push_str(r#"<div class="content">"#);
|
||||
html.push_str("\n");
|
||||
|
||||
// Format content
|
||||
if let Some(line_data) = lines {
|
||||
html.push_str(&self.format_lines(line_data));
|
||||
} else {
|
||||
html.push_str(&self.format_text(content));
|
||||
}
|
||||
|
||||
html.push_str("</div>\n");
|
||||
html.push_str("</body>\n</html>");
|
||||
|
||||
html
|
||||
}
|
||||
|
||||
/// Generate HTML header with scripts and styles
|
||||
fn html_header(&self) -> String {
|
||||
let mut header = String::from("<!DOCTYPE html>\n<html lang=\"en\">\n<head>\n");
|
||||
header.push_str(r#" <meta charset="UTF-8">"#);
|
||||
header.push_str("\n");
|
||||
|
||||
if self.responsive {
|
||||
header.push_str(
|
||||
r#" <meta name="viewport" content="width=device-width, initial-scale=1.0">"#,
|
||||
);
|
||||
header.push_str("\n");
|
||||
}
|
||||
|
||||
header.push_str(" <title>Mathematical Content</title>\n");
|
||||
|
||||
// Math rendering scripts
|
||||
match self.engine {
|
||||
HtmlEngine::MathJax => {
|
||||
header.push_str(r#" <script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(r#" <script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(" <script>\n");
|
||||
header.push_str(" MathJax = {\n");
|
||||
header.push_str(" tex: {\n");
|
||||
header.push_str(r#" inlineMath: [['$', '$'], ['\\(', '\\)']],"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(r#" displayMath: [['$$', '$$'], ['\\[', '\\]']]"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(" }\n");
|
||||
header.push_str(" };\n");
|
||||
header.push_str(" </script>\n");
|
||||
}
|
||||
HtmlEngine::KaTeX => {
|
||||
header.push_str(r#" <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.css">"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(r#" <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/katex.min.js"></script>"#);
|
||||
header.push_str("\n");
|
||||
header.push_str(r#" <script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.9/dist/contrib/auto-render.min.js" onload="renderMathInElement(document.body);"></script>"#);
|
||||
header.push_str("\n");
|
||||
}
|
||||
HtmlEngine::Raw => {
|
||||
// No math rendering
|
||||
}
|
||||
}
|
||||
|
||||
// CSS styling
|
||||
if self.css_styling {
|
||||
header.push_str(" <style>\n");
|
||||
header.push_str(&self.generate_css());
|
||||
header.push_str(" </style>\n");
|
||||
}
|
||||
|
||||
header.push_str("</head>\n");
|
||||
header
|
||||
}
|
||||
|
||||
/// Generate CSS styles
|
||||
fn generate_css(&self) -> String {
|
||||
let mut css = String::new();
|
||||
|
||||
css.push_str(" body {\n");
|
||||
css.push_str(" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;\n");
|
||||
css.push_str(" line-height: 1.6;\n");
|
||||
css.push_str(" max-width: 800px;\n");
|
||||
css.push_str(" margin: 0 auto;\n");
|
||||
css.push_str(" padding: 20px;\n");
|
||||
css.push_str(" }\n");
|
||||
|
||||
// Theme colors
|
||||
match self.theme {
|
||||
HtmlTheme::Light => {
|
||||
css.push_str(" body.theme-light {\n");
|
||||
css.push_str(" background-color: #ffffff;\n");
|
||||
css.push_str(" color: #333333;\n");
|
||||
css.push_str(" }\n");
|
||||
}
|
||||
HtmlTheme::Dark => {
|
||||
css.push_str(" body.theme-dark {\n");
|
||||
css.push_str(" background-color: #1e1e1e;\n");
|
||||
css.push_str(" color: #d4d4d4;\n");
|
||||
css.push_str(" }\n");
|
||||
}
|
||||
HtmlTheme::Auto => {
|
||||
css.push_str(" @media (prefers-color-scheme: dark) {\n");
|
||||
css.push_str(" body { background-color: #1e1e1e; color: #d4d4d4; }\n");
|
||||
css.push_str(" }\n");
|
||||
}
|
||||
}
|
||||
|
||||
css.push_str(" .content { padding: 20px; }\n");
|
||||
css.push_str(" .math-display { text-align: center; margin: 20px 0; }\n");
|
||||
css.push_str(" .math-inline { display: inline; }\n");
|
||||
css.push_str(" .equation-block { margin: 15px 0; padding: 10px; background: #f5f5f5; border-radius: 4px; }\n");
|
||||
css.push_str(" table { border-collapse: collapse; width: 100%; margin: 20px 0; }\n");
|
||||
css.push_str(
|
||||
" th, td { border: 1px solid #ddd; padding: 8px; text-align: left; }\n",
|
||||
);
|
||||
css.push_str(" th { background-color: #f2f2f2; }\n");
|
||||
|
||||
if self.accessibility {
|
||||
css.push_str(" .sr-only { position: absolute; width: 1px; height: 1px; padding: 0; margin: -1px; overflow: hidden; clip: rect(0,0,0,0); border: 0; }\n");
|
||||
}
|
||||
|
||||
css
|
||||
}
|
||||
|
||||
/// Format plain text to HTML
|
||||
fn format_text(&self, text: &str) -> String {
|
||||
let escaped = self.escape_html(text);
|
||||
|
||||
// Convert math delimiters if present
|
||||
let mut html = escaped;
|
||||
|
||||
// Display math $$...$$
|
||||
html = html.replace("$$", "<div class=\"math-display\">$$");
|
||||
html = html.replace("$$", "$$</div>");
|
||||
|
||||
// Inline math $...$
|
||||
// This is simplistic - a real implementation would need proper parsing
|
||||
|
||||
format!("<p>{}</p>", html)
|
||||
}
|
||||
|
||||
/// Format line data to HTML
|
||||
fn format_lines(&self, lines: &[LineData]) -> String {
|
||||
let mut html = String::new();
|
||||
|
||||
for line in lines {
|
||||
match line.line_type.as_str() {
|
||||
"text" => {
|
||||
html.push_str("<p>");
|
||||
html.push_str(&self.escape_html(&line.text));
|
||||
html.push_str("</p>\n");
|
||||
}
|
||||
"math" | "equation" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
html.push_str(r#"<div class="math-display">"#);
|
||||
if self.accessibility {
|
||||
html.push_str(&format!(
|
||||
r#"<span class="sr-only">Equation: {}</span>"#,
|
||||
self.escape_html(&line.text)
|
||||
));
|
||||
}
|
||||
html.push_str(&format!("$${}$$", latex));
|
||||
html.push_str("</div>\n");
|
||||
}
|
||||
"inline_math" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
html.push_str(&format!(r#"<span class="math-inline">${}$</span>"#, latex));
|
||||
}
|
||||
"heading" => {
|
||||
html.push_str(&format!("<h2>{}</h2>\n", self.escape_html(&line.text)));
|
||||
}
|
||||
"table" => {
|
||||
html.push_str(&self.format_table(&line.text));
|
||||
}
|
||||
"image" => {
|
||||
html.push_str(&format!(
|
||||
r#"<img src="{}" alt="Image" loading="lazy">"#,
|
||||
self.escape_html(&line.text)
|
||||
));
|
||||
html.push_str("\n");
|
||||
}
|
||||
_ => {
|
||||
html.push_str("<p>");
|
||||
html.push_str(&self.escape_html(&line.text));
|
||||
html.push_str("</p>\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
html
|
||||
}
|
||||
|
||||
/// Format table to HTML
|
||||
fn format_table(&self, table: &str) -> String {
|
||||
let mut html = String::from("<table>\n");
|
||||
|
||||
let rows: Vec<&str> = table.lines().collect();
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
html.push_str(" <tr>\n");
|
||||
|
||||
let cells: Vec<&str> = row
|
||||
.split('|')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
let tag = if i == 0 { "th" } else { "td" };
|
||||
|
||||
for cell in cells {
|
||||
html.push_str(&format!(
|
||||
" <{}>{}</{}>\n",
|
||||
tag,
|
||||
self.escape_html(cell),
|
||||
tag
|
||||
));
|
||||
}
|
||||
|
||||
html.push_str(" </tr>\n");
|
||||
}
|
||||
|
||||
html.push_str("</table>\n");
|
||||
html
|
||||
}
|
||||
|
||||
/// Escape HTML special characters
|
||||
fn escape_html(&self, text: &str) -> String {
|
||||
text.replace('&', "&")
|
||||
.replace('<', "<")
|
||||
.replace('>', ">")
|
||||
.replace('"', """)
|
||||
.replace('\'', "'")
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for HtmlFormatter {
    /// Equivalent to [`HtmlFormatter::new`]: MathJax, light theme, all
    /// helpers enabled.
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;

    #[test]
    fn test_html_header() {
        let formatter = HtmlFormatter::new().with_engine(HtmlEngine::MathJax);
        let header = formatter.html_header();

        assert!(header.contains("<!DOCTYPE html>"));
        assert!(header.contains("MathJax"));
    }

    #[test]
    fn test_katex_header() {
        let formatter = HtmlFormatter::new().with_engine(HtmlEngine::KaTeX);
        let header = formatter.html_header();

        assert!(header.contains("katex"));
    }

    /// Escaping must emit entity references and strip raw markup.
    /// (The previous assertions checked for raw `<`/`>` characters, which
    /// only passed because the escape function was a no-op.)
    #[test]
    fn test_escape_html() {
        let formatter = HtmlFormatter::new();
        let result = formatter.escape_html("<script>alert('test')</script>");

        assert!(result.contains("&lt;"));
        assert!(result.contains("&gt;"));
        assert!(!result.contains("<script>"));
    }

    #[test]
    fn test_format_lines() {
        let formatter = HtmlFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "Introduction".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "equation".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];

        let result = formatter.format_lines(&lines);
        assert!(result.contains("<p>Introduction</p>"));
        assert!(result.contains("math-display"));
        assert!(result.contains("$$"));
    }

    #[test]
    fn test_dark_theme() {
        let formatter = HtmlFormatter::new().theme(HtmlTheme::Dark);
        let css = formatter.generate_css();

        assert!(css.contains("theme-dark"));
        assert!(css.contains("#1e1e1e"));
    }

    #[test]
    fn test_accessibility() {
        let formatter = HtmlFormatter::new().accessibility(true);
        let lines = vec![LineData {
            line_type: "equation".to_string(),
            text: "x squared".to_string(),
            latex: Some("x^2".to_string()),
            bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
            confidence: 0.98,
            words: None,
        }];

        let result = formatter.format_lines(&lines);
        assert!(result.contains("sr-only"));
        assert!(result.contains("Equation:"));
    }
}
|
||||
354
examples/scipix/src/output/json.rs
Normal file
354
examples/scipix/src/output/json.rs
Normal file
@@ -0,0 +1,354 @@
|
||||
//! JSON API response formatter matching Scipix API specification
|
||||
|
||||
use super::{FormatsData, LineData, OcrResult};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Complete API response matching the Scipix OCR response format.
///
/// Serializes with `formats` flattened into the top level and all `Option`
/// fields omitted when `None`, matching the upstream wire format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiResponse {
    /// Request identifier echoed back to the caller.
    pub request_id: String,

    /// API version string (e.g. "3.0").
    pub version: String,

    /// Source image width in pixels.
    pub image_width: u32,
    /// Source image height in pixels.
    pub image_height: u32,

    /// True when the input was detected as printed text.
    pub is_printed: bool,
    /// True when the input was detected as handwriting.
    pub is_handwritten: bool,

    // Confidence of the auto-rotation decision, if rotation was evaluated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_rotate_confidence: Option<f32>,

    // Degrees the image was auto-rotated, if rotation was evaluated.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_rotate_degrees: Option<i32>,

    /// Overall recognition confidence.
    pub confidence: f32,
    /// Per-element confidence rate.
    pub confidence_rate: f32,

    /// Available output formats, flattened into the response's top level.
    #[serde(flatten)]
    pub formats: FormatsData,

    /// Detailed per-line recognition data, when requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_data: Option<Vec<LineData>>,

    /// Human-readable error message, if processing failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,

    // Structured error details accompanying `error`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_info: Option<ErrorInfo>,

    /// Additional processing metadata (free-form key/value pairs).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<HashMap<String, Value>>,
}
|
||||
|
||||
/// Structured error information attached to failed responses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorInfo {
    /// Machine-readable error code (e.g. "invalid_image").
    pub code: String,
    /// Human-readable error message.
    pub message: String,

    // Optional free-form JSON payload with extra diagnostic detail.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<Value>,
}
|
||||
|
||||
impl ApiResponse {
|
||||
/// Create response from OCR result
|
||||
pub fn from_ocr_result(result: OcrResult) -> Self {
|
||||
Self {
|
||||
request_id: result.request_id,
|
||||
version: result.version,
|
||||
image_width: result.image_width,
|
||||
image_height: result.image_height,
|
||||
is_printed: result.is_printed,
|
||||
is_handwritten: result.is_handwritten,
|
||||
auto_rotate_confidence: Some(result.auto_rotate_confidence),
|
||||
auto_rotate_degrees: Some(result.auto_rotate_degrees),
|
||||
confidence: result.confidence,
|
||||
confidence_rate: result.confidence_rate,
|
||||
formats: result.formats,
|
||||
line_data: result.line_data,
|
||||
error: result.error,
|
||||
error_info: None,
|
||||
metadata: if result.metadata.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(result.metadata)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Create error response
|
||||
pub fn error(request_id: String, code: &str, message: &str) -> Self {
|
||||
Self {
|
||||
request_id,
|
||||
version: "3.0".to_string(),
|
||||
image_width: 0,
|
||||
image_height: 0,
|
||||
is_printed: false,
|
||||
is_handwritten: false,
|
||||
auto_rotate_confidence: None,
|
||||
auto_rotate_degrees: None,
|
||||
confidence: 0.0,
|
||||
confidence_rate: 0.0,
|
||||
formats: FormatsData::default(),
|
||||
line_data: None,
|
||||
error: Some(message.to_string()),
|
||||
error_info: Some(ErrorInfo {
|
||||
code: code.to_string(),
|
||||
message: message.to_string(),
|
||||
details: None,
|
||||
}),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert to JSON string
|
||||
pub fn to_json(&self) -> Result<String, String> {
|
||||
serde_json::to_string(self).map_err(|e| format!("JSON serialization error: {}", e))
|
||||
}
|
||||
|
||||
/// Convert to pretty JSON string
|
||||
pub fn to_json_pretty(&self) -> Result<String, String> {
|
||||
serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization error: {}", e))
|
||||
}
|
||||
|
||||
/// Parse from JSON string
|
||||
pub fn from_json(json: &str) -> Result<Self, String> {
|
||||
serde_json::from_str(json).map_err(|e| format!("JSON parsing error: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// Response for a batch of OCR requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchApiResponse {
    /// Batch identifier.
    pub batch_id: String,
    /// Total number of items in the batch.
    pub total: usize,
    /// Number of items that completed without an error.
    pub completed: usize,
    /// Per-item responses, in submission order.
    pub results: Vec<ApiResponse>,

    // Per-item errors, keyed by index into `results`; omitted when none failed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub errors: Option<Vec<BatchError>>,
}
|
||||
|
||||
/// An error for a single item within a batch, referencing it by index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchError {
    /// Index of the failed item within the batch's `results`.
    pub index: usize,
    /// Structured error details for the failed item.
    pub error: ErrorInfo,
}
|
||||
|
||||
impl BatchApiResponse {
|
||||
pub fn new(batch_id: String, results: Vec<ApiResponse>) -> Self {
|
||||
let total = results.len();
|
||||
let completed = results.iter().filter(|r| r.error.is_none()).count();
|
||||
|
||||
let errors: Vec<BatchError> = results
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, r)| {
|
||||
r.error_info.as_ref().map(|e| BatchError {
|
||||
index: i,
|
||||
error: e.clone(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
batch_id,
|
||||
total,
|
||||
completed,
|
||||
results,
|
||||
errors: if errors.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(errors)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_json(&self) -> Result<String, String> {
|
||||
serde_json::to_string(self).map_err(|e| format!("JSON serialization error: {}", e))
|
||||
}
|
||||
|
||||
pub fn to_json_pretty(&self) -> Result<String, String> {
|
||||
serde_json::to_string_pretty(self).map_err(|e| format!("JSON serialization error: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
/// Incoming API request format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApiRequest {
    /// Image source (URL or base64-encoded data).
    pub src: String,

    /// Requested output format names; server default when omitted.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub formats: Option<Vec<String>>,

    /// OCR processing options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub ocr: Option<OcrOptions>,

    /// Any additional top-level JSON fields are captured here.
    #[serde(flatten)]
    pub metadata: HashMap<String, Value>,
}
|
||||
|
||||
/// Per-request OCR tuning options; every field is optional and omitted from
/// JSON when unset.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrOptions {
    // Delimiter pair used for inline math (e.g. ["$", "$"]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub math_inline_delimiters: Option<Vec<String>>,

    // Delimiter pair used for display math (e.g. ["$$", "$$"]).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub math_display_delimiters: Option<Vec<String>>,

    // Remove extraneous spaces from output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rm_spaces: Option<bool>,

    // Strip font styling commands from output.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rm_fonts: Option<bool>,

    // Treat bare numbers as math by default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub numbers_default_to_math: Option<bool>,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal successful OcrResult fixture shared by the tests.
    fn create_test_result() -> OcrResult {
        OcrResult {
            request_id: "test_123".to_string(),
            version: "3.0".to_string(),
            image_width: 800,
            image_height: 600,
            is_printed: true,
            is_handwritten: false,
            auto_rotate_confidence: 0.95,
            auto_rotate_degrees: 0,
            confidence: 0.98,
            confidence_rate: 0.97,
            formats: FormatsData {
                text: Some("E = mc^2".to_string()),
                latex_normal: Some(r"E = mc^2".to_string()),
                ..Default::default()
            },
            line_data: None,
            error: None,
            metadata: HashMap::new(),
        }
    }

    // Fields from the OcrResult should be carried over verbatim.
    #[test]
    fn test_api_response_from_result() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);

        assert_eq!(response.request_id, "test_123");
        assert_eq!(response.version, "3.0");
        // Exact f32 comparison is fine: the value is assigned, not computed.
        assert_eq!(response.confidence, 0.98);
        assert!(response.formats.text.is_some());
    }

    // Serialization should include the key fields by name.
    #[test]
    fn test_api_response_to_json() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);
        let json = response.to_json().unwrap();

        assert!(json.contains("request_id"));
        assert!(json.contains("test_123"));
        assert!(json.contains("confidence"));
    }

    // Serialize then deserialize must preserve the data.
    #[test]
    fn test_api_response_round_trip() {
        let result = create_test_result();
        let response = ApiResponse::from_ocr_result(result);
        let json = response.to_json().unwrap();
        let parsed = ApiResponse::from_json(&json).unwrap();

        assert_eq!(response.request_id, parsed.request_id);
        assert_eq!(response.confidence, parsed.confidence);
    }

    // Error constructor populates both the plain message and ErrorInfo.
    #[test]
    fn test_error_response() {
        let response = ApiResponse::error(
            "test_456".to_string(),
            "invalid_image",
            "Image format not supported",
        );

        assert_eq!(response.request_id, "test_456");
        assert!(response.error.is_some());
        assert!(response.error_info.is_some());

        let error_info = response.error_info.unwrap();
        assert_eq!(error_info.code, "invalid_image");
    }

    // All-success batch: counts match and no errors list is emitted.
    #[test]
    fn test_batch_response() {
        let result1 = create_test_result();
        let result2 = create_test_result();

        let responses = vec![
            ApiResponse::from_ocr_result(result1),
            ApiResponse::from_ocr_result(result2),
        ];

        let batch = BatchApiResponse::new("batch_789".to_string(), responses);

        assert_eq!(batch.batch_id, "batch_789");
        assert_eq!(batch.total, 2);
        assert_eq!(batch.completed, 2);
        assert!(batch.errors.is_none());
    }

    // Mixed batch: the failing item appears in `errors` with its index.
    #[test]
    fn test_batch_with_errors() {
        let success = create_test_result();
        let error_response =
            ApiResponse::error("fail_1".to_string(), "timeout", "Processing timeout");

        let responses = vec![ApiResponse::from_ocr_result(success), error_response];

        let batch = BatchApiResponse::new("batch_error".to_string(), responses);

        assert_eq!(batch.total, 2);
        assert_eq!(batch.completed, 1);
        assert!(batch.errors.is_some());
        assert_eq!(batch.errors.unwrap().len(), 1);
    }

    // Requests with nested OCR options should serialize cleanly.
    #[test]
    fn test_api_request() {
        let request = ApiRequest {
            src: "https://example.com/image.png".to_string(),
            formats: Some(vec!["text".to_string(), "latex_styled".to_string()]),
            ocr: Some(OcrOptions {
                math_inline_delimiters: Some(vec!["$".to_string(), "$".to_string()]),
                math_display_delimiters: Some(vec!["$$".to_string(), "$$".to_string()]),
                rm_spaces: Some(true),
                rm_fonts: None,
                numbers_default_to_math: Some(false),
            }),
            metadata: HashMap::new(),
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("src"));
        assert!(json.contains("formats"));
    }
}
|
||||
430
examples/scipix/src/output/latex.rs
Normal file
430
examples/scipix/src/output/latex.rs
Normal file
@@ -0,0 +1,430 @@
|
||||
//! LaTeX output formatter with styling and package management
|
||||
|
||||
use super::LineData;
|
||||
|
||||
/// LaTeX document formatter with configurable class, packages, and
/// equation styling. Construct with [`LaTeXFormatter::new`] and customize
/// via the builder-style methods.
#[derive(Clone)]
pub struct LaTeXFormatter {
    // Packages emitted as \usepackage{...} lines in the preamble.
    packages: Vec<String>,
    // Argument of \documentclass{...}.
    document_class: String,
    // Extra preamble text inserted verbatim after the packages.
    preamble: String,
    // When true, display math uses the numbered `equation` environment
    // instead of \[ ... \].
    numbered_equations: bool,
    // Optional (start, end) delimiters wrapped around `format` output.
    custom_delimiters: Option<(String, String)>,
}
|
||||
|
||||
impl LaTeXFormatter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
packages: vec!["amsmath".to_string(), "amssymb".to_string()],
|
||||
document_class: "article".to_string(),
|
||||
preamble: String::new(),
|
||||
numbered_equations: false,
|
||||
custom_delimiters: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_packages(mut self, packages: Vec<String>) -> Self {
|
||||
self.packages = packages;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn add_package(mut self, package: String) -> Self {
|
||||
if !self.packages.contains(&package) {
|
||||
self.packages.push(package);
|
||||
}
|
||||
self
|
||||
}
|
||||
|
||||
pub fn document_class(mut self, class: String) -> Self {
|
||||
self.document_class = class;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn preamble(mut self, preamble: String) -> Self {
|
||||
self.preamble = preamble;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn numbered_equations(mut self, numbered: bool) -> Self {
|
||||
self.numbered_equations = numbered;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn custom_delimiters(mut self, start: String, end: String) -> Self {
|
||||
self.custom_delimiters = Some((start, end));
|
||||
self
|
||||
}
|
||||
|
||||
/// Format plain LaTeX content
|
||||
pub fn format(&self, latex: &str) -> String {
|
||||
// Clean up LaTeX if needed
|
||||
let cleaned = self.clean_latex(latex);
|
||||
|
||||
// Apply custom delimiters if specified
|
||||
if let Some((start, end)) = &self.custom_delimiters {
|
||||
format!("{}{}{}", start, cleaned, end)
|
||||
} else {
|
||||
cleaned
|
||||
}
|
||||
}
|
||||
|
||||
/// Format line data to LaTeX
|
||||
pub fn format_lines(&self, lines: &[LineData]) -> String {
|
||||
let mut output = String::new();
|
||||
let mut in_align = false;
|
||||
|
||||
for line in lines {
|
||||
match line.line_type.as_str() {
|
||||
"text" => {
|
||||
if in_align {
|
||||
output.push_str("\\end{align*}\n\n");
|
||||
in_align = false;
|
||||
}
|
||||
output.push_str(&self.escape_text(&line.text));
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
"math" | "equation" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
|
||||
if self.numbered_equations {
|
||||
output.push_str("\\begin{equation}\n");
|
||||
output.push_str(latex.trim());
|
||||
output.push_str("\n\\end{equation}\n\n");
|
||||
} else {
|
||||
output.push_str("\\[\n");
|
||||
output.push_str(latex.trim());
|
||||
output.push_str("\n\\]\n\n");
|
||||
}
|
||||
}
|
||||
"inline_math" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
output.push_str(&format!("${}$", latex.trim()));
|
||||
}
|
||||
"align" => {
|
||||
if !in_align {
|
||||
output.push_str("\\begin{align*}\n");
|
||||
in_align = true;
|
||||
}
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
output.push_str(latex.trim());
|
||||
output.push_str(" \\\\\n");
|
||||
}
|
||||
"table" => {
|
||||
output.push_str(&self.format_table(&line.text));
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
_ => {
|
||||
output.push_str(&line.text);
|
||||
output.push_str("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if in_align {
|
||||
output.push_str("\\end{align*}\n");
|
||||
}
|
||||
|
||||
output.trim().to_string()
|
||||
}
|
||||
|
||||
/// Format complete LaTeX document
|
||||
pub fn format_document(&self, content: &str) -> String {
|
||||
let mut doc = String::new();
|
||||
|
||||
// Document class
|
||||
doc.push_str(&format!("\\documentclass{{{}}}\n\n", self.document_class));
|
||||
|
||||
// Packages
|
||||
for package in &self.packages {
|
||||
doc.push_str(&format!("\\usepackage{{{}}}\n", package));
|
||||
}
|
||||
doc.push_str("\n");
|
||||
|
||||
// Custom preamble
|
||||
if !self.preamble.is_empty() {
|
||||
doc.push_str(&self.preamble);
|
||||
doc.push_str("\n\n");
|
||||
}
|
||||
|
||||
// Begin document
|
||||
doc.push_str("\\begin{document}\n\n");
|
||||
|
||||
// Content
|
||||
doc.push_str(content);
|
||||
doc.push_str("\n\n");
|
||||
|
||||
// End document
|
||||
doc.push_str("\\end{document}\n");
|
||||
|
||||
doc
|
||||
}
|
||||
|
||||
/// Clean and normalize LaTeX
|
||||
fn clean_latex(&self, latex: &str) -> String {
|
||||
let mut cleaned = latex.to_string();
|
||||
|
||||
// Remove excessive whitespace
|
||||
while cleaned.contains(" ") {
|
||||
cleaned = cleaned.replace(" ", " ");
|
||||
}
|
||||
|
||||
// Normalize line breaks
|
||||
cleaned = cleaned.replace("\r\n", "\n");
|
||||
|
||||
// Ensure proper spacing around operators
|
||||
for op in &["=", "+", "-", r"\times", r"\div"] {
|
||||
let spaced = format!(" {} ", op);
|
||||
cleaned = cleaned.replace(op, &spaced);
|
||||
}
|
||||
|
||||
// Remove duplicate spaces again
|
||||
while cleaned.contains(" ") {
|
||||
cleaned = cleaned.replace(" ", " ");
|
||||
}
|
||||
|
||||
cleaned.trim().to_string()
|
||||
}
|
||||
|
||||
/// Escape special LaTeX characters in text
|
||||
fn escape_text(&self, text: &str) -> String {
|
||||
text.replace('\\', r"\\")
|
||||
.replace('{', r"\{")
|
||||
.replace('}', r"\}")
|
||||
.replace('$', r"\$")
|
||||
.replace('%', r"\%")
|
||||
.replace('_', r"\_")
|
||||
.replace('&', r"\&")
|
||||
.replace('#', r"\#")
|
||||
.replace('^', r"\^")
|
||||
.replace('~', r"\~")
|
||||
}
|
||||
|
||||
/// Format table to LaTeX tabular environment
|
||||
fn format_table(&self, table: &str) -> String {
|
||||
let rows: Vec<&str> = table.lines().collect();
|
||||
if rows.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
// Determine number of columns from first row
|
||||
let num_cols = rows[0].split('|').filter(|s| !s.is_empty()).count();
|
||||
let col_spec = "c".repeat(num_cols);
|
||||
|
||||
let mut output = format!("\\begin{{tabular}}{{{}}}\n", col_spec);
|
||||
output.push_str("\\hline\n");
|
||||
|
||||
for (i, row) in rows.iter().enumerate() {
|
||||
let cells: Vec<&str> = row
|
||||
.split('|')
|
||||
.map(|s| s.trim())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
|
||||
output.push_str(&cells.join(" & "));
|
||||
output.push_str(" \\\\\n");
|
||||
|
||||
if i == 0 {
|
||||
output.push_str("\\hline\n");
|
||||
}
|
||||
}
|
||||
|
||||
output.push_str("\\hline\n");
|
||||
output.push_str("\\end{tabular}");
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Convert inline LaTeX to display math
|
||||
pub fn inline_to_display(&self, latex: &str) -> String {
|
||||
if self.numbered_equations {
|
||||
format!("\\begin{{equation}}\n{}\n\\end{{equation}}", latex.trim())
|
||||
} else {
|
||||
format!("\\[\n{}\n\\]", latex.trim())
|
||||
}
|
||||
}
|
||||
|
||||
/// Add equation label
|
||||
pub fn add_label(&self, latex: &str, label: &str) -> String {
|
||||
format!("{}\n\\label{{{}}}", latex.trim(), label)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LaTeXFormatter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Styled LaTeX formatter with predefined class/package templates selected
/// by [`LaTeXStyle`].
#[allow(dead_code)]
pub struct StyledLaTeXFormatter {
    // Underlying formatter pre-configured for the chosen style.
    base: LaTeXFormatter,
    // The style this formatter was built with (kept for inspection).
    style: LaTeXStyle,
}
|
||||
|
||||
/// Predefined document styles mapping to a document class plus package set.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LaTeXStyle {
    /// `article` class with AMS math, graphicx, hyperref.
    Article,
    /// `report` class; adds geometry.
    Report,
    /// `book` class; adds geometry and fancyhdr.
    Book,
    /// `beamer` presentation class.
    Beamer,
    /// `article` class with amsmath only.
    Minimal,
}
|
||||
|
||||
impl StyledLaTeXFormatter {
|
||||
pub fn new(style: LaTeXStyle) -> Self {
|
||||
let base = match style {
|
||||
LaTeXStyle::Article => LaTeXFormatter::new()
|
||||
.document_class("article".to_string())
|
||||
.with_packages(vec![
|
||||
"amsmath".to_string(),
|
||||
"amssymb".to_string(),
|
||||
"graphicx".to_string(),
|
||||
"hyperref".to_string(),
|
||||
]),
|
||||
LaTeXStyle::Report => LaTeXFormatter::new()
|
||||
.document_class("report".to_string())
|
||||
.with_packages(vec![
|
||||
"amsmath".to_string(),
|
||||
"amssymb".to_string(),
|
||||
"graphicx".to_string(),
|
||||
"hyperref".to_string(),
|
||||
"geometry".to_string(),
|
||||
]),
|
||||
LaTeXStyle::Book => LaTeXFormatter::new()
|
||||
.document_class("book".to_string())
|
||||
.with_packages(vec![
|
||||
"amsmath".to_string(),
|
||||
"amssymb".to_string(),
|
||||
"graphicx".to_string(),
|
||||
"hyperref".to_string(),
|
||||
"geometry".to_string(),
|
||||
"fancyhdr".to_string(),
|
||||
]),
|
||||
LaTeXStyle::Beamer => LaTeXFormatter::new()
|
||||
.document_class("beamer".to_string())
|
||||
.with_packages(vec![
|
||||
"amsmath".to_string(),
|
||||
"amssymb".to_string(),
|
||||
"graphicx".to_string(),
|
||||
]),
|
||||
LaTeXStyle::Minimal => LaTeXFormatter::new()
|
||||
.document_class("article".to_string())
|
||||
.with_packages(vec!["amsmath".to_string()]),
|
||||
};
|
||||
|
||||
Self { base, style }
|
||||
}
|
||||
|
||||
pub fn format_document(
|
||||
&self,
|
||||
content: &str,
|
||||
title: Option<&str>,
|
||||
author: Option<&str>,
|
||||
) -> String {
|
||||
let mut preamble = String::new();
|
||||
|
||||
if let Some(t) = title {
|
||||
preamble.push_str(&format!("\\title{{{}}}\n", t));
|
||||
}
|
||||
if let Some(a) = author {
|
||||
preamble.push_str(&format!("\\author{{{}}}\n", a));
|
||||
}
|
||||
if title.is_some() || author.is_some() {
|
||||
preamble.push_str("\\date{\\today}\n");
|
||||
}
|
||||
|
||||
let formatter = self.base.clone().preamble(preamble);
|
||||
let mut doc = formatter.format_document(content);
|
||||
|
||||
// Add maketitle after \begin{document} if we have title/author
|
||||
if title.is_some() || author.is_some() {
|
||||
doc = doc.replace(
|
||||
"\\begin{document}\n\n",
|
||||
"\\begin{document}\n\n\\maketitle\n\n",
|
||||
);
|
||||
}
|
||||
|
||||
doc
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;

    // Basic formatting leaves the math content intact.
    #[test]
    fn test_format_simple() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.format("E = mc^2");
        assert!(result.contains("mc^2"));
    }

    // A full document contains class, packages, and document environment.
    #[test]
    fn test_format_document() {
        let formatter = LaTeXFormatter::new();
        let doc = formatter.format_document("E = mc^2");

        assert!(doc.contains(r"\documentclass{article}"));
        assert!(doc.contains(r"\usepackage{amsmath}"));
        assert!(doc.contains(r"\begin{document}"));
        assert!(doc.contains("mc^2"));
        assert!(doc.contains(r"\end{document}"));
    }

    // Special characters in plain text must be escaped.
    #[test]
    fn test_escape_text() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.escape_text("Price: $100 & 50%");
        assert!(result.contains(r"\$100"));
        assert!(result.contains(r"\&"));
        assert!(result.contains(r"\%"));
    }

    // Default (unnumbered) display math uses \[ ... \].
    #[test]
    fn test_inline_to_display() {
        let formatter = LaTeXFormatter::new();
        let result = formatter.inline_to_display("x^2 + y^2 = r^2");
        assert!(result.contains(r"\["));
        assert!(result.contains(r"\]"));
    }

    // Title/author trigger \title, \author, and \maketitle emission.
    #[test]
    fn test_styled_formatter() {
        let formatter = StyledLaTeXFormatter::new(LaTeXStyle::Article);
        let doc = formatter.format_document("Content", Some("My Title"), Some("Author Name"));

        assert!(doc.contains(r"\title{My Title}"));
        assert!(doc.contains(r"\author{Author Name}"));
        assert!(doc.contains(r"\maketitle"));
    }

    // Text lines stay as text; equation lines become display math.
    #[test]
    fn test_format_lines() {
        let formatter = LaTeXFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "Introduction".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "equation".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];

        let result = formatter.format_lines(&lines);
        assert!(result.contains("Introduction"));
        assert!(result.contains(r"\[") || result.contains(r"\begin{equation}"));
        assert!(result.contains("mc^2"));
    }
}
|
||||
379
examples/scipix/src/output/mmd.rs
Normal file
379
examples/scipix/src/output/mmd.rs
Normal file
@@ -0,0 +1,379 @@
|
||||
//! Scipix Markdown (MMD) formatter
|
||||
//!
|
||||
//! MMD is an enhanced markdown format that supports:
|
||||
//! - Inline and display math with LaTeX
|
||||
//! - Tables with alignment
|
||||
//! - Chemistry notation (SMILES)
|
||||
//! - Image embedding
|
||||
//! - Structured documents
|
||||
|
||||
use super::{LineData, MathDelimiters};
|
||||
|
||||
/// Scipix Markdown (MMD) formatter.
pub struct MmdFormatter {
    // Math delimiter set used for inline/display expressions.
    delimiters: MathDelimiters,
    // NOTE(review): set by the builder but not read by any method visible in
    // this file — presumably consumed elsewhere; confirm before removing.
    include_metadata: bool,
    // NOTE(review): same as above — builder-settable, not read here.
    preserve_structure: bool,
}
|
||||
|
||||
impl MmdFormatter {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
delimiters: MathDelimiters::default(),
|
||||
include_metadata: false,
|
||||
preserve_structure: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_delimiters(delimiters: MathDelimiters) -> Self {
|
||||
Self {
|
||||
delimiters,
|
||||
include_metadata: false,
|
||||
preserve_structure: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn include_metadata(mut self, include: bool) -> Self {
|
||||
self.include_metadata = include;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn preserve_structure(mut self, preserve: bool) -> Self {
|
||||
self.preserve_structure = preserve;
|
||||
self
|
||||
}
|
||||
|
||||
/// Format line data to MMD
|
||||
pub fn format(&self, lines: &[LineData]) -> String {
|
||||
let mut output = String::new();
|
||||
let mut in_table = false;
|
||||
let mut in_list = false;
|
||||
|
||||
for line in lines {
|
||||
match line.line_type.as_str() {
|
||||
"text" => {
|
||||
if in_table {
|
||||
output.push_str("\n");
|
||||
in_table = false;
|
||||
}
|
||||
if in_list && !line.text.trim_start().starts_with(&['-', '*', '1']) {
|
||||
output.push_str("\n");
|
||||
in_list = false;
|
||||
}
|
||||
output.push_str(&line.text);
|
||||
output.push_str("\n");
|
||||
}
|
||||
"math" | "equation" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
let formatted = self.format_math(latex, true); // display mode
|
||||
output.push_str(&formatted);
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
"inline_math" => {
|
||||
let latex = line.latex.as_ref().unwrap_or(&line.text);
|
||||
let formatted = self.format_math(latex, false); // inline mode
|
||||
output.push_str(&formatted);
|
||||
}
|
||||
"table_row" => {
|
||||
if !in_table {
|
||||
in_table = true;
|
||||
}
|
||||
output.push_str(&self.format_table_row(&line.text));
|
||||
output.push_str("\n");
|
||||
}
|
||||
"list_item" => {
|
||||
if !in_list {
|
||||
in_list = true;
|
||||
}
|
||||
output.push_str(&line.text);
|
||||
output.push_str("\n");
|
||||
}
|
||||
"heading" => {
|
||||
output.push_str(&format!("# {}\n\n", line.text));
|
||||
}
|
||||
"image" => {
|
||||
output.push_str(&self.format_image(&line.text));
|
||||
output.push_str("\n\n");
|
||||
}
|
||||
"chemistry" => {
|
||||
let smiles = line.text.trim();
|
||||
output.push_str(&format!("```smiles\n{}\n```\n\n", smiles));
|
||||
}
|
||||
_ => {
|
||||
// Unknown type, output as text
|
||||
output.push_str(&line.text);
|
||||
output.push_str("\n");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
output.trim().to_string()
|
||||
}
|
||||
|
||||
/// Format LaTeX math expression
|
||||
pub fn format_math(&self, latex: &str, display: bool) -> String {
|
||||
if display {
|
||||
format!(
|
||||
"{}\n{}\n{}",
|
||||
self.delimiters.display_start,
|
||||
latex.trim(),
|
||||
self.delimiters.display_end
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}{}{}",
|
||||
self.delimiters.inline_start,
|
||||
latex.trim(),
|
||||
self.delimiters.inline_end
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Format table row
|
||||
fn format_table_row(&self, row: &str) -> String {
|
||||
// Basic table formatting - split by | and rejoin
|
||||
let cells: Vec<&str> = row.split('|').map(|s| s.trim()).collect();
|
||||
format!("| {} |", cells.join(" | "))
|
||||
}
|
||||
|
||||
/// Format image reference
|
||||
fn format_image(&self, path: &str) -> String {
|
||||
// Extract alt text and path if available
|
||||
if path.contains('[') && path.contains(']') {
|
||||
path.to_string()
|
||||
} else {
|
||||
format!("", path)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert plain text with embedded LaTeX to MMD
|
||||
pub fn from_mixed_text(&self, text: &str) -> String {
|
||||
let mut output = String::new();
|
||||
let mut current = String::new();
|
||||
let mut in_math = false;
|
||||
let mut display_math = false;
|
||||
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
let mut i = 0;
|
||||
|
||||
while i < chars.len() {
|
||||
// Check for display math $$
|
||||
if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '$' {
|
||||
if in_math && display_math {
|
||||
// End display math
|
||||
output.push_str(&self.format_math(¤t, true));
|
||||
current.clear();
|
||||
in_math = false;
|
||||
display_math = false;
|
||||
} else if !in_math {
|
||||
// Start display math
|
||||
if !current.is_empty() {
|
||||
output.push_str(¤t);
|
||||
current.clear();
|
||||
}
|
||||
in_math = true;
|
||||
display_math = true;
|
||||
}
|
||||
i += 2;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check for inline math $
|
||||
if chars[i] == '$' && !display_math {
|
||||
if in_math {
|
||||
// End inline math
|
||||
output.push_str(&self.format_math(¤t, false));
|
||||
current.clear();
|
||||
in_math = false;
|
||||
} else {
|
||||
// Start inline math
|
||||
if !current.is_empty() {
|
||||
output.push_str(¤t);
|
||||
current.clear();
|
||||
}
|
||||
in_math = true;
|
||||
}
|
||||
i += 1;
|
||||
continue;
|
||||
}
|
||||
|
||||
current.push(chars[i]);
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if !current.is_empty() {
|
||||
output.push_str(¤t);
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Format a complete document with frontmatter
|
||||
pub fn format_document(&self, title: &str, content: &str, metadata: Option<&str>) -> String {
|
||||
let mut doc = String::new();
|
||||
|
||||
// Add frontmatter if metadata provided
|
||||
if let Some(meta) = metadata {
|
||||
doc.push_str("---\n");
|
||||
doc.push_str(meta);
|
||||
doc.push_str("\n---\n\n");
|
||||
}
|
||||
|
||||
// Add title
|
||||
doc.push_str(&format!("# {}\n\n", title));
|
||||
|
||||
// Add content
|
||||
doc.push_str(content);
|
||||
|
||||
doc
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MmdFormatter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse MMD back to structured data
pub struct MmdParser;

impl MmdParser {
    /// Construct a parser; the type carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Parse MMD content and extract LaTeX expressions
    ///
    /// Returns `(expression, is_display)` pairs in document order:
    /// `is_display` is `true` for `$$…$$` blocks and `false` for `$…$`
    /// spans. Expressions are trimmed of surrounding whitespace. An
    /// unterminated math span at end of input is silently discarded.
    pub fn extract_latex(&self, content: &str) -> Vec<(String, bool)> {
        let mut expressions = Vec::new();
        let mut buffer = String::new();
        // None = outside math; Some(true) = inside `$$…$$`; Some(false) = inside `$…$`.
        let mut mode: Option<bool> = None;

        let chars: Vec<char> = content.chars().collect();
        let mut pos = 0;

        while pos < chars.len() {
            let double_dollar =
                pos + 1 < chars.len() && chars[pos] == '$' && chars[pos + 1] == '$';

            if double_dollar {
                match mode {
                    // Closing a display block: emit it.
                    Some(true) => {
                        expressions.push((buffer.trim().to_string(), true));
                        buffer.clear();
                        mode = None;
                    }
                    // Opening a display block.
                    None => mode = Some(true),
                    // `$$` inside an inline span is consumed without effect,
                    // matching the original scanner.
                    Some(false) => {}
                }
                pos += 2;
            } else if chars[pos] == '$' && mode != Some(true) {
                if mode == Some(false) {
                    // Closing an inline span: emit it.
                    expressions.push((buffer.trim().to_string(), false));
                    buffer.clear();
                    mode = None;
                } else {
                    // Opening an inline span.
                    mode = Some(false);
                }
                pos += 1;
            } else {
                // Ordinary character: collect it only while inside math.
                if mode.is_some() {
                    buffer.push(chars[pos]);
                }
                pos += 1;
            }
        }

        expressions
    }
}
|
||||
|
||||
impl Default for MmdParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::output::BoundingBox;

    // Inline math is wrapped in single-dollar delimiters with no padding.
    #[test]
    fn test_format_inline_math() {
        let formatter = MmdFormatter::new();
        let result = formatter.format_math("E = mc^2", false);
        assert_eq!(result, "$E = mc^2$");
    }

    // Display math uses `$$` delimiters and preserves the LaTeX body.
    #[test]
    fn test_format_display_math() {
        let formatter = MmdFormatter::new();
        let result = formatter.format_math(r"\int_0^1 x^2 dx", true);
        assert!(result.contains("$$"));
        assert!(result.contains(r"\int_0^1 x^2 dx"));
    }

    // Mixed text/math line data: text lines appear verbatim, math lines as
    // display math.
    #[test]
    fn test_format_lines() {
        let formatter = MmdFormatter::new();
        let lines = vec![
            LineData {
                line_type: "text".to_string(),
                text: "The equation".to_string(),
                latex: None,
                bbox: BoundingBox::new(0.0, 0.0, 100.0, 20.0),
                confidence: 0.95,
                words: None,
            },
            LineData {
                line_type: "math".to_string(),
                text: "E = mc^2".to_string(),
                latex: Some(r"E = mc^2".to_string()),
                bbox: BoundingBox::new(0.0, 25.0, 100.0, 30.0),
                confidence: 0.98,
                words: None,
            },
        ];

        let result = formatter.format(&lines);
        assert!(result.contains("The equation"));
        assert!(result.contains("$$"));
        assert!(result.contains("mc^2"));
    }

    // Round-trips mixed prose + inline math through the formatter.
    #[test]
    fn test_from_mixed_text() {
        let formatter = MmdFormatter::new();
        let text = "The formula $E = mc^2$ is famous.";
        let result = formatter.from_mixed_text(text);
        assert!(result.contains("$E = mc^2$"));
        assert!(result.contains("famous"));
    }

    // The parser distinguishes inline (`$`) from display (`$$`) expressions.
    #[test]
    fn test_extract_latex() {
        let parser = MmdParser::new();
        let content = "Text with $inline$ and $$display$$ math.";
        let expressions = parser.extract_latex(content);

        assert_eq!(expressions.len(), 2);
        assert_eq!(expressions[0].0, "inline");
        assert!(!expressions[0].1); // inline
        assert_eq!(expressions[1].0, "display");
        assert!(expressions[1].1); // display
    }

    // Document assembly: frontmatter fences, metadata, H1 title, body.
    #[test]
    fn test_format_document() {
        let formatter = MmdFormatter::new();
        let doc = formatter.format_document(
            "My Document",
            "Content here",
            Some("author: Test\ndate: 2025-01-01"),
        );

        assert!(doc.contains("---"));
        assert!(doc.contains("author: Test"));
        assert!(doc.contains("# My Document"));
        assert!(doc.contains("Content here"));
    }
}
|
||||
359
examples/scipix/src/output/mod.rs
Normal file
359
examples/scipix/src/output/mod.rs
Normal file
@@ -0,0 +1,359 @@
|
||||
//! Output formatting module for Scipix OCR results
|
||||
//!
|
||||
//! Supports multiple output formats:
|
||||
//! - Text: Plain text extraction
|
||||
//! - LaTeX: Mathematical notation
|
||||
//! - Scipix Markdown (mmd): Enhanced markdown with math
|
||||
//! - MathML: XML-based mathematical markup
|
||||
//! - HTML: Web-ready output with math rendering
|
||||
//! - SMILES: Chemical structure notation
|
||||
//! - DOCX: Microsoft Word format (Office Math ML)
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub mod docx;
|
||||
pub mod formatter;
|
||||
pub mod html;
|
||||
pub mod json;
|
||||
pub mod latex;
|
||||
pub mod mmd;
|
||||
pub mod smiles;
|
||||
|
||||
pub use formatter::{HtmlEngine, MathDelimiters, OutputFormatter};
|
||||
pub use json::ApiResponse;
|
||||
|
||||
/// Output format types supported by Scipix OCR
///
/// Serialized wire names differ from the Rust variant names where an
/// explicit `#[serde(rename)]` is given (`latex_normal`, `latex_styled`,
/// `mathml`, `mmd`, `asciimath`, `smiles`); the rest use the default
/// `snake_case` mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OutputFormat {
    /// Plain text output
    Text,
    /// LaTeX mathematical notation
    #[serde(rename = "latex_normal")]
    LaTeX,
    /// Styled LaTeX with custom packages
    #[serde(rename = "latex_styled")]
    LaTeXStyled,
    /// Mathematical Markup Language
    #[serde(rename = "mathml")]
    MathML,
    /// Scipix Markdown (enhanced markdown)
    #[serde(rename = "mmd")]
    Mmd,
    /// ASCII Math notation
    #[serde(rename = "asciimath")]
    AsciiMath,
    /// HTML with embedded math
    Html,
    /// Chemical structure notation
    #[serde(rename = "smiles")]
    Smiles,
    /// Microsoft Word format
    Docx,
}
|
||||
|
||||
impl OutputFormat {
|
||||
/// Get the file extension for this format
|
||||
pub fn extension(&self) -> &'static str {
|
||||
match self {
|
||||
OutputFormat::Text => "txt",
|
||||
OutputFormat::LaTeX | OutputFormat::LaTeXStyled => "tex",
|
||||
OutputFormat::MathML => "xml",
|
||||
OutputFormat::Mmd => "mmd",
|
||||
OutputFormat::AsciiMath => "txt",
|
||||
OutputFormat::Html => "html",
|
||||
OutputFormat::Smiles => "smi",
|
||||
OutputFormat::Docx => "docx",
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the MIME type for this format
|
||||
pub fn mime_type(&self) -> &'static str {
|
||||
match self {
|
||||
OutputFormat::Text | OutputFormat::AsciiMath => "text/plain",
|
||||
OutputFormat::LaTeX | OutputFormat::LaTeXStyled => "application/x-latex",
|
||||
OutputFormat::MathML => "application/mathml+xml",
|
||||
OutputFormat::Mmd => "text/markdown",
|
||||
OutputFormat::Html => "text/html",
|
||||
OutputFormat::Smiles => "chemical/x-daylight-smiles",
|
||||
OutputFormat::Docx => {
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete OCR result with all possible output formats
///
/// Top-level payload for one processed image. Extra keys in `metadata` are
/// flattened into the same JSON object via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OcrResult {
    /// Request identifier
    pub request_id: String,

    /// Version of the OCR engine
    pub version: String,

    /// Image dimensions
    pub image_width: u32,
    /// Image height in pixels (companion to `image_width`)
    pub image_height: u32,

    /// Processing status
    pub is_printed: bool,
    /// Handwritten-content flag (counterpart of `is_printed`)
    pub is_handwritten: bool,
    /// Confidence of the auto-rotation decision
    /// NOTE(review): value range not evident from this file — presumably 0.0..=1.0; confirm
    pub auto_rotate_confidence: f32,
    /// Rotation suggested/applied by auto-rotate, in degrees
    pub auto_rotate_degrees: i32,

    /// Confidence scores
    pub confidence: f32,
    /// Secondary confidence metric
    /// NOTE(review): semantics vs. `confidence` not visible here — confirm with producer
    pub confidence_rate: f32,

    /// Available output formats
    pub formats: FormatsData,

    /// Detailed line and word data
    /// (omitted from JSON when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub line_data: Option<Vec<LineData>>,

    /// Error information if processing failed
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,

    /// Processing metadata
    /// (arbitrary extra keys, merged into the top-level JSON object)
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}
|
||||
|
||||
/// All available output format data
///
/// Each field holds the rendition of the OCR result in one format and is
/// omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct FormatsData {
    /// Plain-text rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,

    /// Standard LaTeX rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_normal: Option<String>,

    /// LaTeX with styling/custom packages
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_styled: Option<String>,

    /// Simplified LaTeX rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex_simplified: Option<String>,

    /// MathML (XML) rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mathml: Option<String>,

    /// ASCII math rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub asciimath: Option<String>,

    /// Scipix Markdown rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mmd: Option<String>,

    /// HTML rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub html: Option<String>,

    /// SMILES chemical-notation rendition
    #[serde(skip_serializing_if = "Option::is_none")]
    pub smiles: Option<String>,
}
|
||||
|
||||
/// Line-level OCR data with positioning
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LineData {
    /// Line type: text, math, table, image, etc.
    /// (serialized under the JSON key `"type"`)
    #[serde(rename = "type")]
    pub line_type: String,

    /// Content in various formats
    pub text: String,

    /// LaTeX rendition of the line, when available
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,

    /// Bounding box coordinates
    pub bbox: BoundingBox,

    /// Confidence score
    /// NOTE(review): the tests in this module use values in 0.0..=1.0 — confirm that is the contract
    pub confidence: f32,

    /// Word-level data
    /// (omitted from JSON when absent)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<WordData>>,
}
|
||||
|
||||
/// Word-level OCR data
///
/// One recognized word belonging to a [`LineData`] line.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WordData {
    /// Recognized word text
    pub text: String,
    /// Position of the word within the image
    pub bbox: BoundingBox,
    /// Recognition confidence for this word
    pub confidence: f32,

    /// LaTeX rendition of the word, when available
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,
}
|
||||
|
||||
/// Bounding box coordinates (x, y, width, height)
///
/// NOTE(review): units appear to be pixels with a top-left origin,
/// consistent with the OCR line data — confirm against the producer.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct BoundingBox {
    /// Left edge of the box
    pub x: f32,
    /// Top edge of the box
    pub y: f32,
    /// Box width
    pub width: f32,
    /// Box height
    pub height: f32,
}
|
||||
|
||||
impl BoundingBox {
|
||||
pub fn new(x: f32, y: f32, width: f32, height: f32) -> Self {
|
||||
Self {
|
||||
x,
|
||||
y,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn area(&self) -> f32 {
|
||||
self.width * self.height
|
||||
}
|
||||
|
||||
pub fn center(&self) -> (f32, f32) {
|
||||
(self.x + self.width / 2.0, self.y + self.height / 2.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert between output formats
///
/// Supported conversions: identity (any format to itself), LaTeX→Text
/// (lossy command stripping), Mmd→LaTeX (display-math extraction), and
/// LaTeX→HTML (MathJax wrapper with inline `\( \)` delimiters).
///
/// # Errors
/// Returns `Err` with a descriptive message for any unsupported pair.
pub fn convert_format(
    content: &str,
    from: OutputFormat,
    to: OutputFormat,
) -> Result<String, String> {
    // Simple pass-through for same format
    if from == to {
        return Ok(content.to_string());
    }

    // Format-specific conversions
    match (from, to) {
        (OutputFormat::LaTeX, OutputFormat::Text) => {
            // Strip LaTeX commands for plain text
            Ok(strip_latex(content))
        }
        (OutputFormat::Mmd, OutputFormat::LaTeX) => {
            // Extract LaTeX from markdown
            Ok(extract_latex_from_mmd(content))
        }
        (OutputFormat::LaTeX, OutputFormat::Html) => {
            // Wrap LaTeX in HTML with MathJax
            // NOTE(review): polyfill.io was taken over in 2024 and served
            // malicious scripts from this exact URL; MathJax 3 does not
            // strictly require it. Recommend removing the polyfill <script>
            // or pointing at a maintained mirror (e.g. cdnjs/Fastly).
            Ok(format!(
                r#"<!DOCTYPE html>
<html>
<head>
<script src="https://polyfill.io/v3/polyfill.min.js?features=es6"></script>
<script id="MathJax-script" async src="https://cdn.jsdelivr.net/npm/mathjax@3/es5/tex-mml-chtml.js"></script>
</head>
<body>
<p>\({}\)</p>
</body>
</html>"#,
                content
            ))
        }
        _ => Err(format!(
            "Conversion from {:?} to {:?} not supported",
            from, to
        )),
    }
}
|
||||
|
||||
/// Best-effort conversion of LaTeX source to plain text.
///
/// Removes math delimiters, unwraps a few common text-styling commands,
/// strips grouping braces, and collapses spacing commands into single
/// spaces. Intentionally lossy: structural commands such as `\frac` or
/// `\sqrt` are left in place (minus their braces).
fn strip_latex(content: &str) -> String {
    let mut result = content.to_string();

    // Remove math delimiters: \( \), \[ \], $$.
    // NOTE(review): single `$` delimiters are left untouched here — confirm
    // whether they should be stripped too.
    result = result.replace("\\(", "").replace("\\)", "");
    result = result.replace("\\[", "").replace("\\]", "");
    result = result.replace("$$", "");

    // Unwrap common text-styling commands, keeping their argument text.
    result = result.replace('{', "").replace('}', "");

    // Remove standalone spacing/line-break commands, replacing with a space.
    for cmd in &["\\\\", "\\,", "\\;", "\\:", "\\!", "\\quad", "\\qquad"] {
        result = result.replace(cmd, " ");
    }

    result.trim().to_string()
}
|
||||
|
||||
/// Extract the display-math (`$$…$$`) expressions from Scipix Markdown,
/// joined by blank lines.
///
/// Inline `$…$` spans are skipped. Uses the same two-state scan as
/// `MmdParser::extract_latex` so that (a) inline content no longer leaks
/// into the next display block (the previous version never cleared its
/// buffer when an inline span closed) and (b) a lone `$` inside a display
/// block is treated as literal content instead of toggling the math state.
fn extract_latex_from_mmd(content: &str) -> String {
    let mut latex_parts = Vec::new();
    let mut current = String::new();
    let mut in_math = false;
    let mut display_math = false;

    let chars: Vec<char> = content.chars().collect();
    let mut i = 0;

    while i < chars.len() {
        if i + 1 < chars.len() && chars[i] == '$' && chars[i + 1] == '$' {
            if in_math && display_math {
                // Closing `$$`: emit the collected display expression.
                latex_parts.push(current.clone());
                current.clear();
                in_math = false;
                display_math = false;
            } else if !in_math {
                // Opening `$$`.
                in_math = true;
                display_math = true;
            }
            i += 2;
        } else if chars[i] == '$' && !display_math {
            if in_math {
                // Inline span ends: discard its content so it cannot
                // contaminate a later display block.
                current.clear();
            }
            in_math = !in_math;
            i += 1;
        } else if in_math {
            current.push(chars[i]);
            i += 1;
        } else {
            i += 1;
        }
    }

    latex_parts.join("\n\n")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // Extensions have no leading dot and map one per format.
    #[test]
    fn test_output_format_extension() {
        assert_eq!(OutputFormat::Text.extension(), "txt");
        assert_eq!(OutputFormat::LaTeX.extension(), "tex");
        assert_eq!(OutputFormat::Html.extension(), "html");
        assert_eq!(OutputFormat::Mmd.extension(), "mmd");
    }

    #[test]
    fn test_output_format_mime_type() {
        assert_eq!(OutputFormat::Text.mime_type(), "text/plain");
        assert_eq!(OutputFormat::LaTeX.mime_type(), "application/x-latex");
        assert_eq!(OutputFormat::Html.mime_type(), "text/html");
    }

    // area = w*h; center = top-left corner + half extent.
    #[test]
    fn test_bounding_box() {
        let bbox = BoundingBox::new(10.0, 20.0, 100.0, 50.0);
        assert_eq!(bbox.area(), 5000.0);
        assert_eq!(bbox.center(), (60.0, 45.0));
    }

    // Styling commands are unwrapped; their text content survives.
    #[test]
    fn test_strip_latex() {
        let input = r"\text{Hello } \mathbf{World}";
        let output = strip_latex(input);
        assert!(output.contains("Hello"));
        assert!(output.contains("World"));
    }

    // Identity conversion is a pass-through.
    #[test]
    fn test_convert_same_format() {
        let content = "test content";
        let result = convert_format(content, OutputFormat::Text, OutputFormat::Text).unwrap();
        assert_eq!(result, content);
    }
}
|
||||
347
examples/scipix/src/output/smiles.rs
Normal file
347
examples/scipix/src/output/smiles.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
//! SMILES (Simplified Molecular Input Line Entry System) generator
|
||||
//!
|
||||
//! Converts chemical structure representations to SMILES notation.
|
||||
//! This is a simplified implementation - full chemistry support requires
|
||||
//! dedicated chemistry libraries like RDKit or OpenBabel.
|
||||
|
||||
use super::OcrResult;
|
||||
|
||||
/// SMILES notation generator for chemical structures
pub struct SmilesGenerator {
    // Emit canonical SMILES when true.
    // NOTE(review): set by the builder but never read by any method in this
    // file — confirm whether generation is meant to consult it.
    canonical: bool,
    // Keep stereochemistry annotations when true.
    // NOTE(review): also currently unused by the visible methods.
    include_stereochemistry: bool,
}
|
||||
|
||||
impl SmilesGenerator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
canonical: true,
|
||||
include_stereochemistry: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn canonical(mut self, canonical: bool) -> Self {
|
||||
self.canonical = canonical;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn stereochemistry(mut self, include: bool) -> Self {
|
||||
self.include_stereochemistry = include;
|
||||
self
|
||||
}
|
||||
|
||||
/// Generate SMILES from OCR result
|
||||
pub fn generate_from_result(&self, result: &OcrResult) -> Result<String, String> {
|
||||
// Check if SMILES already available
|
||||
if let Some(smiles) = &result.formats.smiles {
|
||||
return Ok(smiles.clone());
|
||||
}
|
||||
|
||||
// Check for chemistry-related content in line data
|
||||
if let Some(line_data) = &result.line_data {
|
||||
for line in line_data {
|
||||
if line.line_type == "chemistry" || line.line_type == "molecule" {
|
||||
return self.parse_chemical_notation(&line.text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err("No chemical structure data found".to_string())
|
||||
}
|
||||
|
||||
/// Parse chemical notation to SMILES
|
||||
/// This is a placeholder - real implementation needs chemistry parsing
|
||||
fn parse_chemical_notation(&self, notation: &str) -> Result<String, String> {
|
||||
// Check if already SMILES format
|
||||
if self.is_smiles(notation) {
|
||||
return Ok(notation.to_string());
|
||||
}
|
||||
|
||||
// Try to parse common chemical formulas
|
||||
if let Some(smiles) = self.simple_formula_to_smiles(notation) {
|
||||
return Ok(smiles);
|
||||
}
|
||||
|
||||
Err(format!("Cannot convert '{}' to SMILES", notation))
|
||||
}
|
||||
|
||||
/// Check if string is already SMILES notation
|
||||
fn is_smiles(&self, s: &str) -> bool {
|
||||
// Basic SMILES characters
|
||||
let smiles_chars = "CNOPSFClBrI[]()=#@+-0123456789cnops";
|
||||
s.chars().all(|c| smiles_chars.contains(c))
|
||||
}
|
||||
|
||||
/// Convert simple chemical formulas to SMILES
|
||||
fn simple_formula_to_smiles(&self, formula: &str) -> Option<String> {
|
||||
// Common chemical formulas
|
||||
match formula.trim() {
|
||||
"H2O" | "water" => Some("O".to_string()),
|
||||
"CO2" | "carbon dioxide" => Some("O=C=O".to_string()),
|
||||
"CH4" | "methane" => Some("C".to_string()),
|
||||
"C2H6" | "ethane" => Some("CC".to_string()),
|
||||
"C2H5OH" | "ethanol" => Some("CCO".to_string()),
|
||||
"CH3COOH" | "acetic acid" => Some("CC(=O)O".to_string()),
|
||||
"C6H6" | "benzene" => Some("c1ccccc1".to_string()),
|
||||
"C6H12O6" | "glucose" => Some("OC[C@H]1OC(O)[C@H](O)[C@@H](O)[C@@H]1O".to_string()),
|
||||
"NH3" | "ammonia" => Some("N".to_string()),
|
||||
"H2SO4" | "sulfuric acid" => Some("OS(=O)(=O)O".to_string()),
|
||||
"NaCl" | "sodium chloride" => Some("[Na+].[Cl-]".to_string()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate SMILES notation
|
||||
pub fn validate(&self, smiles: &str) -> Result<(), String> {
|
||||
// Basic validation checks
|
||||
|
||||
// Check parentheses balance
|
||||
let mut depth = 0;
|
||||
for c in smiles.chars() {
|
||||
match c {
|
||||
'(' => depth += 1,
|
||||
')' => {
|
||||
depth -= 1;
|
||||
if depth < 0 {
|
||||
return Err("Unbalanced parentheses".to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if depth != 0 {
|
||||
return Err("Unbalanced parentheses".to_string());
|
||||
}
|
||||
|
||||
// Check brackets balance
|
||||
let mut depth = 0;
|
||||
for c in smiles.chars() {
|
||||
match c {
|
||||
'[' => depth += 1,
|
||||
']' => {
|
||||
depth -= 1;
|
||||
if depth < 0 {
|
||||
return Err("Unbalanced brackets".to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
if depth != 0 {
|
||||
return Err("Unbalanced brackets".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Convert SMILES to molecular formula
|
||||
pub fn to_molecular_formula(&self, smiles: &str) -> Result<String, String> {
|
||||
self.validate(smiles)?;
|
||||
|
||||
// Simplified formula extraction
|
||||
// Real implementation would parse the SMILES properly
|
||||
let mut counts: std::collections::HashMap<char, usize> = std::collections::HashMap::new();
|
||||
|
||||
for c in smiles.chars() {
|
||||
if c.is_alphabetic() && c.is_uppercase() {
|
||||
*counts.entry(c).or_insert(0) += 1;
|
||||
}
|
||||
}
|
||||
|
||||
let mut formula = String::new();
|
||||
// Only use single-character elements for simplicity
|
||||
for element in &['C', 'H', 'N', 'O', 'S', 'P', 'F'] {
|
||||
if let Some(&count) = counts.get(element) {
|
||||
formula.push(*element);
|
||||
if count > 1 {
|
||||
formula.push_str(&count.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if formula.is_empty() {
|
||||
Err("Could not determine molecular formula".to_string())
|
||||
} else {
|
||||
Ok(formula)
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate molecular weight (approximate)
|
||||
pub fn molecular_weight(&self, smiles: &str) -> Result<f32, String> {
|
||||
self.validate(smiles)?;
|
||||
|
||||
// Simplified atomic weights
|
||||
let weights: std::collections::HashMap<char, f32> = [
|
||||
('C', 12.01),
|
||||
('H', 1.008),
|
||||
('N', 14.01),
|
||||
('O', 16.00),
|
||||
('S', 32.07),
|
||||
('P', 30.97),
|
||||
('F', 19.00),
|
||||
]
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let mut total_weight = 0.0;
|
||||
|
||||
for c in smiles.chars() {
|
||||
if let Some(&weight) = weights.get(&c) {
|
||||
total_weight += weight;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(total_weight)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SmilesGenerator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// SMILES parser for extracting structure information
pub struct SmilesParser;

impl SmilesParser {
    /// Construct a parser; the type carries no state.
    pub fn new() -> Self {
        Self
    }

    /// Count atoms in SMILES notation
    ///
    /// Heuristic: an uppercase letter starts an atom symbol, and a lowercase
    /// letter immediately following it is folded into a two-letter symbol
    /// (`Cl`, `Br`).
    /// NOTE(review): aromatic lowercase atoms (`c`, `n`, …) are not counted
    /// on their own, and an aromatic atom directly after an aliphatic one
    /// (e.g. `Cc1ccccc1`) is misread as a two-letter symbol — confirm this
    /// simplification is acceptable.
    pub fn count_atoms(&self, smiles: &str) -> std::collections::HashMap<String, usize> {
        let mut counts = std::collections::HashMap::new();

        let mut i = 0;
        let chars: Vec<char> = smiles.chars().collect();

        while i < chars.len() {
            if chars[i].is_uppercase() {
                let mut atom = String::from(chars[i]);

                // Check for two-letter atoms (Cl, Br, etc.)
                if i + 1 < chars.len() && chars[i + 1].is_lowercase() {
                    atom.push(chars[i + 1]);
                    i += 1;
                }

                *counts.entry(atom).or_insert(0) += 1;
            }
            i += 1;
        }

        counts
    }

    /// Extract ring information
    ///
    /// Returns every ring-closure digit in order of appearance; a ring is
    /// opened and closed by the same digit, so each ring contributes two
    /// entries. (The original looped with an unused `enumerate` index;
    /// `filter_map` over `to_digit(10)` is equivalent — base-10 `to_digit`
    /// succeeds exactly on ASCII digits.)
    pub fn find_rings(&self, smiles: &str) -> Vec<usize> {
        smiles
            .chars()
            .filter_map(|c| c.to_digit(10))
            .map(|d| d as usize)
            .collect()
    }
}
|
||||
|
||||
impl Default for SmilesParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    // `is_smiles` is a character-set check, so punctuation/spaces fail it.
    #[test]
    fn test_is_smiles() {
        let gen = SmilesGenerator::new();

        assert!(gen.is_smiles("CCO"));
        assert!(gen.is_smiles("c1ccccc1"));
        assert!(gen.is_smiles("CC(=O)O"));
        assert!(!gen.is_smiles("not smiles!"));
    }

    // Lookup table handles both formula ("H2O") and name ("benzene") keys.
    #[test]
    fn test_simple_formula_conversion() {
        let gen = SmilesGenerator::new();

        assert_eq!(gen.simple_formula_to_smiles("H2O"), Some("O".to_string()));
        assert_eq!(
            gen.simple_formula_to_smiles("CO2"),
            Some("O=C=O".to_string())
        );
        assert_eq!(gen.simple_formula_to_smiles("CH4"), Some("C".to_string()));
        assert_eq!(
            gen.simple_formula_to_smiles("benzene"),
            Some("c1ccccc1".to_string())
        );
    }

    #[test]
    fn test_validate_smiles() {
        let gen = SmilesGenerator::new();

        assert!(gen.validate("CCO").is_ok());
        assert!(gen.validate("CC(O)C").is_ok());
        assert!(gen.validate("c1ccccc1").is_ok());

        assert!(gen.validate("CC(O").is_err()); // Unbalanced
        assert!(gen.validate("CC)O").is_err()); // Unbalanced
    }

    #[test]
    fn test_molecular_formula() {
        let gen = SmilesGenerator::new();

        let formula = gen.to_molecular_formula("CCO").unwrap();
        assert!(formula.contains('C'));
        assert!(formula.contains('O'));
    }

    #[test]
    fn test_molecular_weight() {
        let gen = SmilesGenerator::new();

        // Water: H2O (but SMILES is just "O", representing OH2)
        let weight = gen.molecular_weight("O").unwrap();
        assert!(weight > 0.0);

        // Ethanol: C2H6O — true weight ~46, but the simplified sum counts
        // only explicit atoms (2×C + O ≈ 40.02), hence the loose bound.
        let weight = gen.molecular_weight("CCO").unwrap();
        assert!(weight > 30.0); // Should be around 46
    }

    #[test]
    fn test_count_atoms() {
        let parser = SmilesParser::new();

        let counts = parser.count_atoms("CCO");
        assert_eq!(counts.get("C"), Some(&2));
        assert_eq!(counts.get("O"), Some(&1));

        let counts = parser.count_atoms("CC(=O)O");
        assert_eq!(counts.get("C"), Some(&2));
        assert_eq!(counts.get("O"), Some(&2));
    }

    // Each ring contributes its opening and closing digit.
    #[test]
    fn test_find_rings() {
        let parser = SmilesParser::new();

        let rings = parser.find_rings("c1ccccc1");
        assert_eq!(rings, vec![1, 1]);

        let rings = parser.find_rings("C1CC1");
        assert_eq!(rings, vec![1, 1]);
    }
}
|
||||
353
examples/scipix/src/preprocess/deskew.rs
Normal file
353
examples/scipix/src/preprocess/deskew.rs
Normal file
@@ -0,0 +1,353 @@
|
||||
//! Skew detection and correction using Hough transform
|
||||
|
||||
use super::{PreprocessError, Result};
|
||||
use image::{GrayImage, Luma};
|
||||
use imageproc::edges::canny;
|
||||
use imageproc::geometric_transformations::{rotate_about_center, Interpolation};
|
||||
use std::collections::BTreeMap;
|
||||
use std::f32;
|
||||
|
||||
/// Detect skew angle using Hough transform
|
||||
///
|
||||
/// Applies edge detection and Hough transform to find dominant lines,
|
||||
/// then calculates average skew angle.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Skew angle in degrees (positive = clockwise)
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::deskew::detect_skew_angle;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let angle = detect_skew_angle(&image).unwrap();
|
||||
/// println!("Detected skew: {:.2}°", angle);
|
||||
/// ```
|
||||
pub fn detect_skew_angle(image: &GrayImage) -> Result<f32> {
|
||||
let (width, height) = image.dimensions();
|
||||
|
||||
if width < 20 || height < 20 {
|
||||
return Err(PreprocessError::InvalidParameters(
|
||||
"Image too small for skew detection".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Apply Canny edge detection
|
||||
let edges = canny(image, 50.0, 100.0);
|
||||
|
||||
// Perform Hough transform to detect lines
|
||||
let angles = detect_lines_hough(&edges, width, height)?;
|
||||
|
||||
if angles.is_empty() {
|
||||
return Ok(0.0);
|
||||
}
|
||||
|
||||
// Calculate weighted average angle
|
||||
let total_weight: f32 = angles.values().sum();
|
||||
let weighted_sum: f32 = angles
|
||||
.iter()
|
||||
.map(|(angle_key, weight)| (*angle_key as f32 / 10.0) * weight)
|
||||
.sum();
|
||||
|
||||
let average_angle = if total_weight > 0.0 {
|
||||
weighted_sum / total_weight
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
|
||||
Ok(average_angle)
|
||||
}
|
||||
|
||||
/// Detect lines using Hough transform
///
/// Returns map of angles to their confidence weights. Keys are degrees
/// multiplied by 10 (integer keys for `BTreeMap`); values are accumulated
/// vote counts for that angle.
fn detect_lines_hough(edges: &GrayImage, width: u32, height: u32) -> Result<BTreeMap<i32, f32>> {
    // Diagonal length bounds |rho|.
    // NOTE(review): `width * width + height * height` is u32 arithmetic and
    // can overflow for images larger than ~65k px per side — confirm inputs
    // are bounded or widen to u64.
    let max_rho = ((width * width + height * height) as f32).sqrt() as usize;
    // NOTE(review): theta spans 0..360°, but (θ, ρ) and (θ+180°, −ρ) denote
    // the same line, so half the range is redundant voting — 0..180 may have
    // been intended.
    let num_angles = 360;

    // Accumulator array for Hough space
    let mut accumulator = vec![vec![0u32; max_rho]; num_angles];

    // Populate accumulator
    for y in 0..height {
        for x in 0..width {
            if edges.get_pixel(x, y)[0] > 128 {
                // Edge pixel found
                for theta_idx in 0..num_angles {
                    let theta = (theta_idx as f32) * std::f32::consts::PI / 180.0;
                    let rho = (x as f32) * theta.cos() + (y as f32) * theta.sin();
                    // NOTE(review): rho ∈ [-diag, diag] but the offset is only
                    // diag/2, so part of the range falls outside the
                    // accumulator and is silently dropped by the bounds check
                    // below — verify this is intended.
                    let rho_idx = (rho + max_rho as f32 / 2.0) as usize;

                    if rho_idx < max_rho {
                        accumulator[theta_idx][rho_idx] += 1;
                    }
                }
            }
        }
    }

    // Find peaks in accumulator
    let mut angle_votes: BTreeMap<i32, f32> = BTreeMap::new();
    let threshold = (width.min(height) / 10) as u32; // Adaptive threshold

    for theta_idx in 0..num_angles {
        for rho_idx in 0..max_rho {
            let votes = accumulator[theta_idx][rho_idx];
            if votes > threshold {
                let angle = (theta_idx as f32) - 180.0; // Convert to -180 to 180
                let normalized_angle = normalize_angle(angle);

                // Only consider angles near horizontal (within ±45°)
                if normalized_angle.abs() < 45.0 {
                    // Use integer keys for BTreeMap (angle * 10 to preserve precision)
                    let key = (normalized_angle * 10.0) as i32;
                    *angle_votes.entry(key).or_insert(0.0) += votes as f32;
                }
            }
        }
    }

    Ok(angle_votes)
}
|
||||
|
||||
/// Normalize angle to -45 to +45 degree range
///
/// Folds the input into (-90, 90) via mod-180 arithmetic — angles that
/// differ by 180° describe the same line — then clamps the result to the
/// ±45° window used for skew correction.
fn normalize_angle(angle: f32) -> f32 {
    let folded = angle % 180.0;
    let centered = if folded > 90.0 {
        folded - 180.0
    } else if folded < -90.0 {
        folded + 180.0
    } else {
        folded
    };
    centered.clamp(-45.0, 45.0)
}
|
||||
|
||||
/// Deskew image using detected skew angle
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `angle` - Skew angle in degrees (from detect_skew_angle)
|
||||
///
|
||||
/// # Returns
|
||||
/// Deskewed image with white background fill
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::deskew::{detect_skew_angle, deskew_image};
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let angle = detect_skew_angle(&image).unwrap();
|
||||
/// let deskewed = deskew_image(&image, angle).unwrap();
|
||||
/// ```
|
||||
pub fn deskew_image(image: &GrayImage, angle: f32) -> Result<GrayImage> {
|
||||
if angle.abs() < 0.1 {
|
||||
// No deskewing needed
|
||||
return Ok(image.clone());
|
||||
}
|
||||
|
||||
let radians = -angle.to_radians(); // Negate for correct direction
|
||||
let deskewed = rotate_about_center(
|
||||
image,
|
||||
radians,
|
||||
Interpolation::Bilinear,
|
||||
Luma([255]), // White background
|
||||
);
|
||||
|
||||
Ok(deskewed)
|
||||
}
|
||||
|
||||
/// Auto-deskew image with confidence threshold
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `max_angle` - Maximum angle to correct (degrees)
|
||||
///
|
||||
/// # Returns
|
||||
/// Tuple of (deskewed_image, angle_applied)
|
||||
pub fn auto_deskew(image: &GrayImage, max_angle: f32) -> Result<(GrayImage, f32)> {
|
||||
let angle = detect_skew_angle(image)?;
|
||||
|
||||
if angle.abs() <= max_angle {
|
||||
let deskewed = deskew_image(image, angle)?;
|
||||
Ok((deskewed, angle))
|
||||
} else {
|
||||
// Angle too large, don't correct
|
||||
Ok((image.clone(), 0.0))
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect skew using projection profile method (alternative approach)
|
||||
///
|
||||
/// This is a faster but less accurate method compared to Hough transform
|
||||
pub fn detect_skew_projection(image: &GrayImage) -> Result<f32> {
|
||||
let angles = [
|
||||
-45.0, -30.0, -15.0, -10.0, -5.0, 0.0, 5.0, 10.0, 15.0, 30.0, 45.0,
|
||||
];
|
||||
let mut max_variance = 0.0;
|
||||
let mut best_angle = 0.0;
|
||||
|
||||
for &angle in &angles {
|
||||
let variance = calculate_projection_variance(image, angle);
|
||||
if variance > max_variance {
|
||||
max_variance = variance;
|
||||
best_angle = angle;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(best_angle)
|
||||
}
|
||||
|
||||
/// Calculate projection variance for a given angle
|
||||
fn calculate_projection_variance(image: &GrayImage, angle: f32) -> f32 {
|
||||
let (width, height) = image.dimensions();
|
||||
let rad = angle.to_radians();
|
||||
let cos_a = rad.cos();
|
||||
let sin_a = rad.sin();
|
||||
|
||||
let mut projection = vec![0u32; height as usize];
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = image.get_pixel(x, y)[0];
|
||||
if pixel < 128 {
|
||||
let proj_y = ((y as f32) * cos_a - (x as f32) * sin_a) as i32;
|
||||
if proj_y >= 0 && proj_y < height as i32 {
|
||||
projection[proj_y as usize] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate variance
|
||||
if projection.is_empty() {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mean = projection.iter().sum::<u32>() as f32 / projection.len() as f32;
|
||||
projection
|
||||
.iter()
|
||||
.map(|&x| {
|
||||
let diff = x as f32 - mean;
|
||||
diff * diff
|
||||
})
|
||||
.sum::<f32>()
|
||||
/ projection.len() as f32
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 200x100 white image with four horizontal black lines,
    /// a crude stand-in for straight lines of text.
    fn create_test_image() -> GrayImage {
        let mut img = GrayImage::new(200, 100);

        // Fill with white
        for pixel in img.pixels_mut() {
            *pixel = Luma([255]);
        }

        // Draw some horizontal lines (simulating text)
        for y in [20, 40, 60, 80] {
            for x in 10..190 {
                img.put_pixel(x, y, Luma([0]));
            }
        }

        img
    }

    #[test]
    fn test_detect_skew_straight() {
        let img = create_test_image();
        let angle = detect_skew_angle(&img);

        assert!(angle.is_ok());
        let a = angle.unwrap();
        // Should detect near-zero skew for straight lines
        assert!(a.abs() < 10.0);
    }

    #[test]
    fn test_deskew_image() {
        let img = create_test_image();

        // Deskew by 5 degrees
        let deskewed = deskew_image(&img, 5.0);
        assert!(deskewed.is_ok());

        // Rotation about the center keeps the canvas size unchanged
        let result = deskewed.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    #[test]
    fn test_deskew_no_change() {
        let img = create_test_image();

        // Deskew by ~0 degrees (below the 0.1° correction threshold)
        let deskewed = deskew_image(&img, 0.05);
        assert!(deskewed.is_ok());

        let result = deskewed.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    #[test]
    fn test_auto_deskew() {
        let img = create_test_image();
        let result = auto_deskew(&img, 15.0);

        assert!(result.is_ok());
        let (deskewed, angle) = result.unwrap();

        assert_eq!(deskewed.dimensions(), img.dimensions());
        // The applied correction must stay within the requested limit
        assert!(angle.abs() <= 15.0);
    }

    #[test]
    fn test_normalize_angle() {
        assert!((normalize_angle(0.0) - 0.0).abs() < 0.01);

        // Test normalization behavior
        let angle_100 = normalize_angle(100.0);
        assert!(angle_100.abs() <= 45.0); // Should be clamped to ±45°

        let angle_neg100 = normalize_angle(-100.0);
        assert!(angle_neg100.abs() <= 45.0); // Should be clamped to ±45°

        assert!((normalize_angle(50.0) - 45.0).abs() < 0.01); // Clamped to 45
        assert!((normalize_angle(-50.0) - -45.0).abs() < 0.01); // Clamped to -45
    }

    #[test]
    fn test_detect_skew_projection() {
        let img = create_test_image();
        let angle = detect_skew_projection(&img);

        assert!(angle.is_ok());
        let a = angle.unwrap();
        // Coarse method: allow a wider tolerance than the Hough detector
        assert!(a.abs() < 20.0);
    }

    #[test]
    fn test_skew_small_image_error() {
        // Images below the detector's minimum size are rejected with Err
        let small_img = GrayImage::new(10, 10);
        let result = detect_skew_angle(&small_img);
        assert!(result.is_err());
    }

    #[test]
    fn test_projection_variance() {
        let img = create_test_image();

        let var_0 = calculate_projection_variance(&img, 0.0);
        let var_30 = calculate_projection_variance(&img, 30.0);

        // Variance at 0° should be higher for horizontal lines
        assert!(var_0 > 0.0);
        println!("Variance at 0°: {}, at 30°: {}", var_0, var_30);
    }
}
|
||||
420
examples/scipix/src/preprocess/enhancement.rs
Normal file
420
examples/scipix/src/preprocess/enhancement.rs
Normal file
@@ -0,0 +1,420 @@
|
||||
//! Image enhancement functions for improving OCR accuracy
|
||||
|
||||
use super::{PreprocessError, Result};
|
||||
use image::{GrayImage, Luma};
|
||||
use std::cmp;
|
||||
|
||||
/// Contrast Limited Adaptive Histogram Equalization (CLAHE)
///
/// Improves local contrast while avoiding over-amplification of noise.
/// Divides image into tiles and applies histogram equalization with clipping.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `clip_limit` - Contrast clipping limit (typically 2.0-4.0)
/// * `tile_size` - Size of contextual regions (typically 8x8 or 16x16)
///
/// # Returns
/// Enhanced image with improved local contrast
///
/// # Errors
/// Returns `PreprocessError::InvalidParameters` when `tile_size` is zero
/// or `clip_limit` is not positive.
///
/// # Example
/// ```no_run
/// use ruvector_scipix::preprocess::enhancement::clahe;
/// # use image::GrayImage;
/// # let image = GrayImage::new(100, 100);
/// let enhanced = clahe(&image, 2.0, 8).unwrap();
/// ```
pub fn clahe(image: &GrayImage, clip_limit: f32, tile_size: u32) -> Result<GrayImage> {
    if tile_size == 0 || clip_limit <= 0.0 {
        return Err(PreprocessError::InvalidParameters(
            "Invalid CLAHE parameters".to_string(),
        ));
    }

    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);

    // Ceiling division: partial tiles at the right/bottom edges still get a CDF.
    let tiles_x = (width + tile_size - 1) / tile_size;
    let tiles_y = (height + tile_size - 1) / tile_size;

    // Compute histograms and CDFs for each tile
    let mut tile_cdfs = vec![vec![Vec::new(); tiles_x as usize]; tiles_y as usize];

    for ty in 0..tiles_y {
        for tx in 0..tiles_x {
            let x_start = tx * tile_size;
            let y_start = ty * tile_size;
            let x_end = cmp::min(x_start + tile_size, width);
            let y_end = cmp::min(y_start + tile_size, height);

            let cdf = compute_tile_cdf(image, x_start, y_start, x_end, y_end, clip_limit);
            tile_cdfs[ty as usize][tx as usize] = cdf;
        }
    }

    // Interpolate and apply transformation
    for y in 0..height {
        for x in 0..width {
            let pixel = image.get_pixel(x, y)[0];

            // Find tile coordinates
            let tx = (x as f32 / tile_size as f32).floor();
            let ty = (y as f32 / tile_size as f32).floor();

            // Calculate interpolation weights
            // (fractional position of the pixel inside its tile)
            let x_ratio = (x as f32 / tile_size as f32) - tx;
            let y_ratio = (y as f32 / tile_size as f32) - ty;

            let tx = tx as usize;
            let ty = ty as usize;

            // Bilinear interpolation between neighboring tiles.
            // Interior pixels blend four tile CDFs; pixels in the last
            // tile row/column fall back to fewer neighbors.
            let value = if tx < tiles_x as usize - 1 && ty < tiles_y as usize - 1 {
                let v00 = tile_cdfs[ty][tx][pixel as usize];
                let v10 = tile_cdfs[ty][tx + 1][pixel as usize];
                let v01 = tile_cdfs[ty + 1][tx][pixel as usize];
                let v11 = tile_cdfs[ty + 1][tx + 1][pixel as usize];

                let v0 = v00 * (1.0 - x_ratio) + v10 * x_ratio;
                let v1 = v01 * (1.0 - x_ratio) + v11 * x_ratio;

                v0 * (1.0 - y_ratio) + v1 * y_ratio
            } else if tx < tiles_x as usize - 1 {
                // Last tile row: interpolate horizontally only
                let v0 = tile_cdfs[ty][tx][pixel as usize];
                let v1 = tile_cdfs[ty][tx + 1][pixel as usize];
                v0 * (1.0 - x_ratio) + v1 * x_ratio
            } else if ty < tiles_y as usize - 1 {
                // Last tile column: interpolate vertically only
                let v0 = tile_cdfs[ty][tx][pixel as usize];
                let v1 = tile_cdfs[ty + 1][tx][pixel as usize];
                v0 * (1.0 - y_ratio) + v1 * y_ratio
            } else {
                // Bottom-right corner tile: no neighbors to blend with
                tile_cdfs[ty][tx][pixel as usize]
            };

            // CDF values are normalized to [0, 1]; rescale to 8-bit range.
            result.put_pixel(x, y, Luma([(value * 255.0) as u8]));
        }
    }

    Ok(result)
}
|
||||
|
||||
/// Compute clipped histogram and CDF for a tile
|
||||
fn compute_tile_cdf(
|
||||
image: &GrayImage,
|
||||
x_start: u32,
|
||||
y_start: u32,
|
||||
x_end: u32,
|
||||
y_end: u32,
|
||||
clip_limit: f32,
|
||||
) -> Vec<f32> {
|
||||
// Calculate histogram
|
||||
let mut histogram = [0u32; 256];
|
||||
let mut pixel_count = 0;
|
||||
|
||||
for y in y_start..y_end {
|
||||
for x in x_start..x_end {
|
||||
let pixel = image.get_pixel(x, y)[0];
|
||||
histogram[pixel as usize] += 1;
|
||||
pixel_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if pixel_count == 0 {
|
||||
return vec![0.0; 256];
|
||||
}
|
||||
|
||||
// Apply contrast limiting
|
||||
let clip_limit_actual = (clip_limit * pixel_count as f32 / 256.0) as u32;
|
||||
let mut clipped_total = 0u32;
|
||||
|
||||
for h in histogram.iter_mut() {
|
||||
if *h > clip_limit_actual {
|
||||
clipped_total += *h - clip_limit_actual;
|
||||
*h = clip_limit_actual;
|
||||
}
|
||||
}
|
||||
|
||||
// Redistribute clipped pixels
|
||||
let redistribute = clipped_total / 256;
|
||||
let remainder = clipped_total % 256;
|
||||
|
||||
for (i, h) in histogram.iter_mut().enumerate() {
|
||||
*h += redistribute;
|
||||
if i < remainder as usize {
|
||||
*h += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute cumulative distribution function (CDF)
|
||||
let mut cdf = vec![0.0; 256];
|
||||
let mut cumsum = 0u32;
|
||||
|
||||
for (i, &h) in histogram.iter().enumerate() {
|
||||
cumsum += h;
|
||||
cdf[i] = cumsum as f32 / pixel_count as f32;
|
||||
}
|
||||
|
||||
cdf
|
||||
}
|
||||
|
||||
/// Normalize brightness across the image
|
||||
///
|
||||
/// Adjusts image to have mean brightness of 128
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Brightness-normalized image
|
||||
pub fn normalize_brightness(image: &GrayImage) -> GrayImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let pixel_count = (width * height) as f32;
|
||||
|
||||
// Calculate mean brightness
|
||||
let sum: u32 = image.pixels().map(|p| p[0] as u32).sum();
|
||||
let mean = sum as f32 / pixel_count;
|
||||
|
||||
let target_mean = 128.0;
|
||||
let adjustment = target_mean - mean;
|
||||
|
||||
// Apply adjustment
|
||||
let mut result = GrayImage::new(width, height);
|
||||
for (x, y, pixel) in image.enumerate_pixels() {
|
||||
let adjusted = (pixel[0] as f32 + adjustment).clamp(0.0, 255.0) as u8;
|
||||
result.put_pixel(x, y, Luma([adjusted]));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Remove shadows from document image
|
||||
///
|
||||
/// Uses morphological operations to estimate and subtract background
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Image with reduced shadows
|
||||
pub fn remove_shadows(image: &GrayImage) -> Result<GrayImage> {
|
||||
let (width, height) = image.dimensions();
|
||||
|
||||
// Estimate background using dilation (morphological closing)
|
||||
let kernel_size = (width.min(height) / 20).max(15) as usize;
|
||||
let background = estimate_background(image, kernel_size);
|
||||
|
||||
// Subtract background
|
||||
let mut result = GrayImage::new(width, height);
|
||||
for (x, y, pixel) in image.enumerate_pixels() {
|
||||
let bg = background.get_pixel(x, y)[0] as i32;
|
||||
let fg = pixel[0] as i32;
|
||||
|
||||
// Normalize: (foreground / background) * 255
|
||||
let normalized = if bg > 0 {
|
||||
((fg as f32 / bg as f32) * 255.0).min(255.0) as u8
|
||||
} else {
|
||||
fg as u8
|
||||
};
|
||||
|
||||
result.put_pixel(x, y, Luma([normalized]));
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Estimate background using max filter (dilation)
|
||||
fn estimate_background(image: &GrayImage, kernel_size: usize) -> GrayImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut background = GrayImage::new(width, height);
|
||||
let half_kernel = (kernel_size / 2) as i32;
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let mut max_val = 0u8;
|
||||
|
||||
// Find maximum in kernel window
|
||||
for ky in -(half_kernel)..=half_kernel {
|
||||
for kx in -(half_kernel)..=half_kernel {
|
||||
let px = (x as i32 + kx).clamp(0, width as i32 - 1) as u32;
|
||||
let py = (y as i32 + ky).clamp(0, height as i32 - 1) as u32;
|
||||
|
||||
let val = image.get_pixel(px, py)[0];
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
background.put_pixel(x, y, Luma([max_val]));
|
||||
}
|
||||
}
|
||||
|
||||
background
|
||||
}
|
||||
|
||||
/// Enhance contrast using simple linear stretch
|
||||
///
|
||||
/// Maps min-max range to 0-255
|
||||
pub fn contrast_stretch(image: &GrayImage) -> GrayImage {
|
||||
// Find min and max values
|
||||
let mut min_val = 255u8;
|
||||
let mut max_val = 0u8;
|
||||
|
||||
for pixel in image.pixels() {
|
||||
let val = pixel[0];
|
||||
if val < min_val {
|
||||
min_val = val;
|
||||
}
|
||||
if val > max_val {
|
||||
max_val = val;
|
||||
}
|
||||
}
|
||||
|
||||
if min_val == max_val {
|
||||
return image.clone();
|
||||
}
|
||||
|
||||
// Stretch contrast
|
||||
let (width, height) = image.dimensions();
|
||||
let mut result = GrayImage::new(width, height);
|
||||
let range = (max_val - min_val) as f32;
|
||||
|
||||
for (x, y, pixel) in image.enumerate_pixels() {
|
||||
let val = pixel[0];
|
||||
let stretched = ((val - min_val) as f32 / range * 255.0) as u8;
|
||||
result.put_pixel(x, y, Luma([stretched]));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 100x100 diagonal-gradient image (values 0..=99).
    fn create_test_image() -> GrayImage {
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = ((x + y) / 2) as u8;
                img.put_pixel(x, y, Luma([val]));
            }
        }
        img
    }

    #[test]
    fn test_clahe() {
        let img = create_test_image();
        let enhanced = clahe(&img, 2.0, 8);

        assert!(enhanced.is_ok());
        // Enhancement must not change the image size
        let result = enhanced.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    #[test]
    fn test_clahe_invalid_params() {
        let img = create_test_image();

        // Invalid tile size
        let result = clahe(&img, 2.0, 0);
        assert!(result.is_err());

        // Invalid clip limit
        let result = clahe(&img, -1.0, 8);
        assert!(result.is_err());
    }

    #[test]
    fn test_normalize_brightness() {
        let img = create_test_image();
        let normalized = normalize_brightness(&img);

        assert_eq!(normalized.dimensions(), img.dimensions());

        // Check that mean is closer to 128
        let sum: u32 = normalized.pixels().map(|p| p[0] as u32).sum();
        let mean = sum as f32 / (100.0 * 100.0);

        // Clamping at 0/255 can shift the mean slightly, hence the tolerance
        assert!((mean - 128.0).abs() < 5.0);
    }

    #[test]
    fn test_remove_shadows() {
        let img = create_test_image();
        let result = remove_shadows(&img);

        assert!(result.is_ok());
        let shadow_removed = result.unwrap();
        assert_eq!(shadow_removed.dimensions(), img.dimensions());
    }

    #[test]
    fn test_contrast_stretch() {
        // Create low contrast image
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = 100 + ((x + y) / 10) as u8; // Range: 100-119
                img.put_pixel(x, y, Luma([val]));
            }
        }

        let stretched = contrast_stretch(&img);

        // Check that range is now 0-255
        let mut min_val = 255u8;
        let mut max_val = 0u8;
        for pixel in stretched.pixels() {
            let val = pixel[0];
            if val < min_val {
                min_val = val;
            }
            if val > max_val {
                max_val = val;
            }
        }

        assert_eq!(min_val, 0);
        assert_eq!(max_val, 255);
    }

    #[test]
    fn test_contrast_stretch_uniform() {
        // Uniform image should remain unchanged
        let mut img = GrayImage::new(50, 50);
        for pixel in img.pixels_mut() {
            *pixel = Luma([128]);
        }

        let stretched = contrast_stretch(&img);

        for pixel in stretched.pixels() {
            assert_eq!(pixel[0], 128);
        }
    }

    #[test]
    fn test_estimate_background() {
        let img = create_test_image();
        let background = estimate_background(&img, 5);

        assert_eq!(background.dimensions(), img.dimensions());

        // Background should have higher values (max filter)
        for (orig, bg) in img.pixels().zip(background.pixels()) {
            assert!(bg[0] >= orig[0]);
        }
    }

    #[test]
    fn test_clahe_various_tile_sizes() {
        let img = create_test_image();

        for tile_size in [4, 8, 16, 32] {
            let result = clahe(&img, 2.0, tile_size);
            assert!(result.is_ok());
        }
    }
}
|
||||
277
examples/scipix/src/preprocess/mod.rs
Normal file
277
examples/scipix/src/preprocess/mod.rs
Normal file
@@ -0,0 +1,277 @@
|
||||
//! Image preprocessing module for OCR pipeline
|
||||
//!
|
||||
//! This module provides comprehensive image preprocessing capabilities including:
|
||||
//! - Image transformations (grayscale, blur, sharpen, threshold)
|
||||
//! - Rotation detection and correction
|
||||
//! - Skew correction (deskewing)
|
||||
//! - Image enhancement (CLAHE, normalization)
|
||||
//! - Text region segmentation
|
||||
//! - Complete preprocessing pipeline with parallel processing
|
||||
|
||||
pub mod deskew;
|
||||
pub mod enhancement;
|
||||
pub mod pipeline;
|
||||
pub mod rotation;
|
||||
pub mod segmentation;
|
||||
pub mod transforms;
|
||||
|
||||
use image::{DynamicImage, GrayImage};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
|
||||
/// Preprocessing error types
#[derive(Error, Debug)]
pub enum PreprocessError {
    /// Failure while loading or decoding the input image.
    #[error("Image loading error: {0}")]
    ImageLoad(String),

    /// Caller supplied out-of-range or inconsistent parameters.
    #[error("Invalid parameters: {0}")]
    InvalidParameters(String),

    /// Generic failure inside a processing step.
    #[error("Processing error: {0}")]
    Processing(String),

    /// Failure during text-region segmentation.
    #[error("Segmentation error: {0}")]
    Segmentation(String),
}
|
||||
|
||||
/// Result type for preprocessing operations
///
/// Shorthand for `std::result::Result` specialized to [`PreprocessError`].
pub type Result<T> = std::result::Result<T, PreprocessError>;
|
||||
|
||||
/// Preprocessing options for configuring the pipeline
///
/// All fields have sensible defaults; construct with
/// `PreprocessOptions::default()` and override only what you need.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreprocessOptions {
    /// Enable rotation detection and correction
    pub auto_rotate: bool,

    /// Enable skew detection and correction
    pub auto_deskew: bool,

    /// Enable contrast enhancement
    pub enhance_contrast: bool,

    /// Enable denoising
    pub denoise: bool,

    /// Binarization threshold (None for auto Otsu)
    pub threshold: Option<u8>,

    /// Enable adaptive thresholding
    pub adaptive_threshold: bool,

    /// Adaptive threshold window size (pixels)
    pub adaptive_window_size: u32,

    /// Target image width (None to keep original)
    pub target_width: Option<u32>,

    /// Target image height (None to keep original)
    pub target_height: Option<u32>,

    /// Enable text region detection
    ///
    /// NOTE(review): this flag is not consumed by [`preprocess`]; use
    /// [`detect_text_regions`] to obtain regions.
    pub detect_regions: bool,

    /// Gaussian blur sigma for denoising
    pub blur_sigma: f32,

    /// CLAHE clip limit for contrast enhancement
    pub clahe_clip_limit: f32,

    /// CLAHE tile size
    pub clahe_tile_size: u32,
}
|
||||
|
||||
impl Default for PreprocessOptions {
    /// Defaults enable the full cleanup pipeline (rotate, deskew,
    /// contrast enhancement, denoise, adaptive threshold) with no
    /// resizing.
    fn default() -> Self {
        Self {
            auto_rotate: true,
            auto_deskew: true,
            enhance_contrast: true,
            denoise: true,
            threshold: None,
            adaptive_threshold: true,
            adaptive_window_size: 15,
            target_width: None,
            target_height: None,
            detect_regions: true,
            blur_sigma: 1.0,
            clahe_clip_limit: 2.0,
            clahe_tile_size: 8,
        }
    }
}
|
||||
|
||||
/// Type of text region
///
/// Classification label attached to a [`TextRegion`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RegionType {
    /// Regular text
    Text,
    /// Mathematical equation
    Math,
    /// Table
    Table,
    /// Figure/Image
    Figure,
    /// Unknown/Other
    Unknown,
}
|
||||
|
||||
/// Detected text region with bounding box
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextRegion {
    /// Region type
    pub region_type: RegionType,

    /// Bounding box (x, y, width, height) in pixels
    pub bbox: (u32, u32, u32, u32),

    /// Confidence score (0.0 to 1.0)
    pub confidence: f32,

    /// Average text height in pixels
    pub text_height: f32,

    /// Detected baseline angle in degrees
    pub baseline_angle: f32,
}
|
||||
|
||||
/// Main preprocessing function with configurable options
///
/// Convenience wrapper that translates [`PreprocessOptions`] into a
/// `pipeline::PreprocessPipeline` and runs it.
///
/// # Arguments
/// * `image` - Input image to preprocess
/// * `options` - Preprocessing configuration options
///
/// # Returns
/// Preprocessed grayscale image ready for OCR
///
/// # Example
/// ```no_run
/// use image::open;
/// use ruvector_scipix::preprocess::{preprocess, PreprocessOptions};
///
/// let img = open("document.jpg").unwrap();
/// let options = PreprocessOptions::default();
/// let processed = preprocess(&img, &options).unwrap();
/// ```
pub fn preprocess(image: &DynamicImage, options: &PreprocessOptions) -> Result<GrayImage> {
    // NOTE(review): `options.detect_regions` is not forwarded to the
    // builder — confirm this is intended; region detection is exposed
    // separately via `detect_text_regions`.
    pipeline::PreprocessPipeline::builder()
        .auto_rotate(options.auto_rotate)
        .auto_deskew(options.auto_deskew)
        .enhance_contrast(options.enhance_contrast)
        .denoise(options.denoise)
        .blur_sigma(options.blur_sigma)
        .clahe_clip_limit(options.clahe_clip_limit)
        .clahe_tile_size(options.clahe_tile_size)
        .threshold(options.threshold)
        .adaptive_threshold(options.adaptive_threshold)
        .adaptive_window_size(options.adaptive_window_size)
        .target_size(options.target_width, options.target_height)
        .build()
        .process(image)
}
|
||||
|
||||
/// Detect text regions in an image
///
/// Thin wrapper around `segmentation::find_text_regions`.
///
/// # Arguments
/// * `image` - Input grayscale image
/// * `min_region_size` - Minimum region size in pixels
///
/// # Returns
/// Vector of detected text regions with metadata
///
/// # Errors
/// Propagates any error raised by the segmentation step.
///
/// # Example
/// ```no_run
/// use image::open;
/// use ruvector_scipix::preprocess::detect_text_regions;
///
/// let img = open("document.jpg").unwrap().to_luma8();
/// let regions = detect_text_regions(&img, 100).unwrap();
/// println!("Found {} text regions", regions.len());
/// ```
pub fn detect_text_regions(image: &GrayImage, min_region_size: u32) -> Result<Vec<TextRegion>> {
    segmentation::find_text_regions(image, min_region_size)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};

    /// Gradient-pattern RGB test image of the given size.
    fn create_test_image(width: u32, height: u32) -> DynamicImage {
        let mut img = RgbImage::new(width, height);

        // Create a simple test pattern
        for y in 0..height {
            for x in 0..width {
                let val = ((x + y) % 256) as u8;
                img.put_pixel(x, y, Rgb([val, val, val]));
            }
        }

        DynamicImage::ImageRgb8(img)
    }

    #[test]
    fn test_preprocess_default_options() {
        let img = create_test_image(100, 100);
        let options = PreprocessOptions::default();

        let result = preprocess(&img, &options);
        assert!(result.is_ok());

        // No target size set → dimensions are preserved
        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }

    #[test]
    fn test_preprocess_with_resize() {
        let img = create_test_image(200, 200);
        let mut options = PreprocessOptions::default();
        options.target_width = Some(100);
        options.target_height = Some(100);

        let result = preprocess(&img, &options);
        assert!(result.is_ok());

        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }

    #[test]
    fn test_preprocess_options_builder() {
        // Struct-update syntax overrides a subset of the defaults
        let options = PreprocessOptions {
            auto_rotate: false,
            auto_deskew: false,
            enhance_contrast: true,
            denoise: true,
            threshold: Some(128),
            adaptive_threshold: false,
            ..Default::default()
        };

        assert!(!options.auto_rotate);
        assert!(!options.auto_deskew);
        assert!(options.enhance_contrast);
        assert_eq!(options.threshold, Some(128));
    }

    #[test]
    fn test_region_type_serialization() {
        let region = TextRegion {
            region_type: RegionType::Math,
            bbox: (10, 20, 100, 50),
            confidence: 0.95,
            text_height: 12.0,
            baseline_angle: 0.5,
        };

        // Round-trip through JSON must preserve the fields checked below
        let json = serde_json::to_string(&region).unwrap();
        let deserialized: TextRegion = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.region_type, RegionType::Math);
        assert_eq!(deserialized.bbox, (10, 20, 100, 50));
        assert!((deserialized.confidence - 0.95).abs() < 0.001);
    }
}
|
||||
456
examples/scipix/src/preprocess/pipeline.rs
Normal file
456
examples/scipix/src/preprocess/pipeline.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
//! Complete preprocessing pipeline with builder pattern and parallel processing
|
||||
|
||||
use super::Result;
|
||||
use crate::preprocess::{deskew, enhancement, rotation, transforms};
|
||||
use image::{DynamicImage, GrayImage};
|
||||
use rayon::prelude::*;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Progress callback type
///
/// Invoked as `(step_name, progress)`, where `progress` is a completion
/// fraction reported by the pipeline.
pub type ProgressCallback = Arc<dyn Fn(&str, f32) + Send + Sync>;
|
||||
|
||||
/// Complete preprocessing pipeline with configurable steps
pub struct PreprocessPipeline {
    /// Detect and correct page rotation before other steps.
    auto_rotate: bool,
    /// Detect and correct small skew angles.
    auto_deskew: bool,
    /// Apply CLAHE contrast enhancement.
    enhance_contrast: bool,
    /// Apply Gaussian-blur denoising.
    denoise: bool,
    /// Sigma for the denoising Gaussian blur.
    blur_sigma: f32,
    /// CLAHE contrast clip limit.
    clahe_clip_limit: f32,
    /// CLAHE tile size in pixels.
    clahe_tile_size: u32,
    /// Fixed binarization threshold; `None` falls back to Otsu.
    threshold: Option<u8>,
    /// Use adaptive thresholding instead of a global threshold.
    adaptive_threshold: bool,
    /// Window size for adaptive thresholding.
    adaptive_window_size: u32,
    /// Optional output width (applied as a final resize step).
    target_width: Option<u32>,
    /// Optional output height (applied as a final resize step).
    target_height: Option<u32>,
    /// Optional observer invoked with (step name, progress fraction).
    progress_callback: Option<ProgressCallback>,
}
|
||||
|
||||
/// Builder for preprocessing pipeline
///
/// Fields mirror [`PreprocessPipeline`] one-to-one; see that struct for
/// per-field semantics. Values are set via the fluent setter methods.
pub struct PreprocessPipelineBuilder {
    auto_rotate: bool,
    auto_deskew: bool,
    enhance_contrast: bool,
    denoise: bool,
    blur_sigma: f32,
    clahe_clip_limit: f32,
    clahe_tile_size: u32,
    threshold: Option<u8>,
    adaptive_threshold: bool,
    adaptive_window_size: u32,
    target_width: Option<u32>,
    target_height: Option<u32>,
    progress_callback: Option<ProgressCallback>,
}
|
||||
|
||||
impl Default for PreprocessPipelineBuilder {
    /// Defaults match `PreprocessOptions::default()` in the parent module:
    /// full cleanup enabled, adaptive thresholding, no resizing, no
    /// progress callback.
    fn default() -> Self {
        Self {
            auto_rotate: true,
            auto_deskew: true,
            enhance_contrast: true,
            denoise: true,
            blur_sigma: 1.0,
            clahe_clip_limit: 2.0,
            clahe_tile_size: 8,
            threshold: None,
            adaptive_threshold: true,
            adaptive_window_size: 15,
            target_width: None,
            target_height: None,
            progress_callback: None,
        }
    }
}
|
||||
|
||||
impl PreprocessPipelineBuilder {
    /// Create a builder initialized with the default settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable or disable automatic rotation detection and correction.
    pub fn auto_rotate(mut self, enable: bool) -> Self {
        self.auto_rotate = enable;
        self
    }

    /// Enable or disable automatic skew detection and correction.
    pub fn auto_deskew(mut self, enable: bool) -> Self {
        self.auto_deskew = enable;
        self
    }

    /// Enable or disable CLAHE contrast enhancement.
    pub fn enhance_contrast(mut self, enable: bool) -> Self {
        self.enhance_contrast = enable;
        self
    }

    /// Enable or disable Gaussian-blur denoising.
    pub fn denoise(mut self, enable: bool) -> Self {
        self.denoise = enable;
        self
    }

    /// Set the sigma used by the denoising Gaussian blur.
    pub fn blur_sigma(mut self, sigma: f32) -> Self {
        self.blur_sigma = sigma;
        self
    }

    /// Set the CLAHE contrast clip limit.
    pub fn clahe_clip_limit(mut self, limit: f32) -> Self {
        self.clahe_clip_limit = limit;
        self
    }

    /// Set the CLAHE tile size in pixels.
    pub fn clahe_tile_size(mut self, size: u32) -> Self {
        self.clahe_tile_size = size;
        self
    }

    /// Set a fixed binarization threshold; `None` selects Otsu's method
    /// when adaptive thresholding is disabled.
    pub fn threshold(mut self, threshold: Option<u8>) -> Self {
        self.threshold = threshold;
        self
    }

    /// Enable or disable adaptive thresholding (when enabled it takes
    /// precedence over the fixed threshold in `process`).
    pub fn adaptive_threshold(mut self, enable: bool) -> Self {
        self.adaptive_threshold = enable;
        self
    }

    /// Set the window size used by adaptive thresholding.
    pub fn adaptive_window_size(mut self, size: u32) -> Self {
        self.adaptive_window_size = size;
        self
    }

    /// Set the output dimensions; pass `None` to keep the original size.
    pub fn target_size(mut self, width: Option<u32>, height: Option<u32>) -> Self {
        self.target_width = width;
        self.target_height = height;
        self
    }

    /// Register a progress observer invoked as `(step_name, fraction)`.
    pub fn progress_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(&str, f32) + Send + Sync + 'static,
    {
        self.progress_callback = Some(Arc::new(callback));
        self
    }

    /// Consume the builder and produce the configured pipeline.
    pub fn build(self) -> PreprocessPipeline {
        PreprocessPipeline {
            auto_rotate: self.auto_rotate,
            auto_deskew: self.auto_deskew,
            enhance_contrast: self.enhance_contrast,
            denoise: self.denoise,
            blur_sigma: self.blur_sigma,
            clahe_clip_limit: self.clahe_clip_limit,
            clahe_tile_size: self.clahe_tile_size,
            threshold: self.threshold,
            adaptive_threshold: self.adaptive_threshold,
            adaptive_window_size: self.adaptive_window_size,
            target_width: self.target_width,
            target_height: self.target_height,
            progress_callback: self.progress_callback,
        }
    }
}
|
||||
|
||||
impl PreprocessPipeline {
|
||||
/// Create a new pipeline builder
|
||||
pub fn builder() -> PreprocessPipelineBuilder {
|
||||
PreprocessPipelineBuilder::new()
|
||||
}
|
||||
|
||||
/// Report progress if callback is set
|
||||
fn report_progress(&self, step: &str, progress: f32) {
|
||||
if let Some(callback) = &self.progress_callback {
|
||||
callback(step, progress);
|
||||
}
|
||||
}
|
||||
|
||||
/// Process a single image through the complete pipeline
|
||||
///
|
||||
/// # Pipeline steps:
|
||||
/// 1. Convert to grayscale
|
||||
/// 2. Detect and correct rotation (if enabled)
|
||||
/// 3. Detect and correct skew (if enabled)
|
||||
/// 4. Enhance contrast with CLAHE (if enabled)
|
||||
/// 5. Denoise with Gaussian blur (if enabled)
|
||||
/// 6. Apply thresholding (binary or adaptive)
|
||||
/// 7. Resize to target dimensions (if specified)
|
||||
pub fn process(&self, image: &DynamicImage) -> Result<GrayImage> {
|
||||
self.report_progress("Starting preprocessing", 0.0);
|
||||
|
||||
// Step 1: Convert to grayscale
|
||||
self.report_progress("Converting to grayscale", 0.1);
|
||||
let mut gray = transforms::to_grayscale(image);
|
||||
|
||||
// Step 2: Auto-rotate
|
||||
if self.auto_rotate {
|
||||
self.report_progress("Detecting rotation", 0.2);
|
||||
let angle = rotation::detect_rotation(&gray)?;
|
||||
|
||||
if angle.abs() > 0.5 {
|
||||
self.report_progress("Correcting rotation", 0.25);
|
||||
gray = rotation::rotate_image(&gray, -angle)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Auto-deskew
|
||||
if self.auto_deskew {
|
||||
self.report_progress("Detecting skew", 0.3);
|
||||
let angle = deskew::detect_skew_angle(&gray)?;
|
||||
|
||||
if angle.abs() > 0.5 {
|
||||
self.report_progress("Correcting skew", 0.35);
|
||||
gray = deskew::deskew_image(&gray, angle)?;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Enhance contrast
|
||||
if self.enhance_contrast {
|
||||
self.report_progress("Enhancing contrast", 0.5);
|
||||
gray = enhancement::clahe(&gray, self.clahe_clip_limit, self.clahe_tile_size)?;
|
||||
}
|
||||
|
||||
// Step 5: Denoise
|
||||
if self.denoise {
|
||||
self.report_progress("Denoising", 0.6);
|
||||
gray = transforms::gaussian_blur(&gray, self.blur_sigma)?;
|
||||
}
|
||||
|
||||
// Step 6: Thresholding
|
||||
self.report_progress("Applying threshold", 0.7);
|
||||
gray = if self.adaptive_threshold {
|
||||
transforms::adaptive_threshold(&gray, self.adaptive_window_size)?
|
||||
} else if let Some(threshold_val) = self.threshold {
|
||||
transforms::threshold(&gray, threshold_val)
|
||||
} else {
|
||||
// Auto Otsu threshold
|
||||
let threshold_val = transforms::otsu_threshold(&gray)?;
|
||||
transforms::threshold(&gray, threshold_val)
|
||||
};
|
||||
|
||||
// Step 7: Resize
|
||||
if let (Some(width), Some(height)) = (self.target_width, self.target_height) {
|
||||
self.report_progress("Resizing", 0.9);
|
||||
gray = image::imageops::resize(
|
||||
&gray,
|
||||
width,
|
||||
height,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
}
|
||||
|
||||
self.report_progress("Preprocessing complete", 1.0);
|
||||
Ok(gray)
|
||||
}
|
||||
|
||||
/// Process multiple images in parallel
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `images` - Vector of images to process
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of preprocessed images in the same order
|
||||
pub fn process_batch(&self, images: Vec<DynamicImage>) -> Result<Vec<GrayImage>> {
|
||||
images
|
||||
.into_par_iter()
|
||||
.map(|img| self.process(&img))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Process image and return intermediate results from each step
|
||||
///
|
||||
/// Useful for debugging and visualization
|
||||
pub fn process_with_intermediates(
|
||||
&self,
|
||||
image: &DynamicImage,
|
||||
) -> Result<Vec<(String, GrayImage)>> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
// Step 1: Grayscale
|
||||
let mut gray = transforms::to_grayscale(image);
|
||||
results.push(("01_grayscale".to_string(), gray.clone()));
|
||||
|
||||
// Step 2: Rotation
|
||||
if self.auto_rotate {
|
||||
let angle = rotation::detect_rotation(&gray)?;
|
||||
if angle.abs() > 0.5 {
|
||||
gray = rotation::rotate_image(&gray, -angle)?;
|
||||
results.push(("02_rotated".to_string(), gray.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Deskew
|
||||
if self.auto_deskew {
|
||||
let angle = deskew::detect_skew_angle(&gray)?;
|
||||
if angle.abs() > 0.5 {
|
||||
gray = deskew::deskew_image(&gray, angle)?;
|
||||
results.push(("03_deskewed".to_string(), gray.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
// Step 4: Enhancement
|
||||
if self.enhance_contrast {
|
||||
gray = enhancement::clahe(&gray, self.clahe_clip_limit, self.clahe_tile_size)?;
|
||||
results.push(("04_enhanced".to_string(), gray.clone()));
|
||||
}
|
||||
|
||||
// Step 5: Denoise
|
||||
if self.denoise {
|
||||
gray = transforms::gaussian_blur(&gray, self.blur_sigma)?;
|
||||
results.push(("05_denoised".to_string(), gray.clone()));
|
||||
}
|
||||
|
||||
// Step 6: Threshold
|
||||
gray = if self.adaptive_threshold {
|
||||
transforms::adaptive_threshold(&gray, self.adaptive_window_size)?
|
||||
} else if let Some(threshold_val) = self.threshold {
|
||||
transforms::threshold(&gray, threshold_val)
|
||||
} else {
|
||||
let threshold_val = transforms::otsu_threshold(&gray)?;
|
||||
transforms::threshold(&gray, threshold_val)
|
||||
};
|
||||
results.push(("06_thresholded".to_string(), gray.clone()));
|
||||
|
||||
// Step 7: Resize
|
||||
if let (Some(width), Some(height)) = (self.target_width, self.target_height) {
|
||||
gray = image::imageops::resize(
|
||||
&gray,
|
||||
width,
|
||||
height,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
results.push(("07_resized".to_string(), gray.clone()));
|
||||
}
|
||||
|
||||
Ok(results)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use image::{Rgb, RgbImage};

    /// Build a 100x100 RGB image with a diagonal gray gradient, used as the
    /// input for every pipeline test below.
    fn create_test_image() -> DynamicImage {
        let mut img = RgbImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                // Pixel value grows with x + y, producing a smooth gradient.
                let val = ((x + y) / 2) as u8;
                img.put_pixel(x, y, Rgb([val, val, val]));
            }
        }
        DynamicImage::ImageRgb8(img)
    }

    /// Builder setters must be reflected verbatim on the built pipeline.
    #[test]
    fn test_pipeline_builder() {
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .denoise(true)
            .blur_sigma(1.5)
            .build();

        assert!(!pipeline.auto_rotate);
        assert!(pipeline.denoise);
        assert!((pipeline.blur_sigma - 1.5).abs() < 0.001);
    }

    /// Without a target size configured, processing preserves dimensions.
    #[test]
    fn test_pipeline_process() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .build();

        let result = pipeline.process(&img);
        assert!(result.is_ok());

        let processed = result.unwrap();
        assert_eq!(processed.width(), 100);
        assert_eq!(processed.height(), 100);
    }

    /// With both target dimensions set, the output is resized to them.
    #[test]
    fn test_pipeline_with_resize() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .target_size(Some(50), Some(50))
            .auto_rotate(false)
            .auto_deskew(false)
            .build();

        let result = pipeline.process(&img);
        assert!(result.is_ok());

        let processed = result.unwrap();
        assert_eq!(processed.width(), 50);
        assert_eq!(processed.height(), 50);
    }

    /// Batch processing returns one result per input, in order.
    #[test]
    fn test_pipeline_batch_processing() {
        let images = vec![
            create_test_image(),
            create_test_image(),
            create_test_image(),
        ];

        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .build();

        let results = pipeline.process_batch(images);
        assert!(results.is_ok());

        let processed = results.unwrap();
        assert_eq!(processed.len(), 3);
    }

    /// Intermediate snapshots include (at least) the unconditional grayscale
    /// and threshold steps.
    #[test]
    fn test_pipeline_intermediates() {
        let img = create_test_image();
        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .enhance_contrast(true)
            .denoise(true)
            .build();

        let result = pipeline.process_with_intermediates(&img);
        assert!(result.is_ok());

        let intermediates = result.unwrap();
        assert!(!intermediates.is_empty());
        assert!(intermediates
            .iter()
            .any(|(name, _)| name.contains("grayscale")));
        assert!(intermediates
            .iter()
            .any(|(name, _)| name.contains("thresholded")));
    }

    /// The progress callback must fire for the start and completion steps.
    #[test]
    fn test_progress_callback() {
        use std::sync::{Arc, Mutex};

        // Collect step names through a shared, mutex-guarded vector.
        let progress_steps = Arc::new(Mutex::new(Vec::new()));
        let progress_clone = Arc::clone(&progress_steps);

        let pipeline = PreprocessPipeline::builder()
            .auto_rotate(false)
            .auto_deskew(false)
            .progress_callback(move |step, _progress| {
                progress_clone.lock().unwrap().push(step.to_string());
            })
            .build();

        let img = create_test_image();
        let _ = pipeline.process(&img);

        let steps = progress_steps.lock().unwrap();
        assert!(!steps.is_empty());
        assert!(steps.iter().any(|s| s.contains("Starting")));
        assert!(steps.iter().any(|s| s.contains("complete")));
    }
}
|
||||
319
examples/scipix/src/preprocess/rotation.rs
Normal file
319
examples/scipix/src/preprocess/rotation.rs
Normal file
@@ -0,0 +1,319 @@
|
||||
//! Rotation detection and correction using projection profiles
|
||||
|
||||
use super::{PreprocessError, Result};
|
||||
use image::{GrayImage, Luma};
|
||||
use imageproc::geometric_transformations::{rotate_about_center, Interpolation};
|
||||
use std::f32;
|
||||
|
||||
/// Detect rotation angle using projection profile analysis
|
||||
///
|
||||
/// Uses horizontal and vertical projection profiles to detect document rotation.
|
||||
/// Returns angle in degrees (typically in range -45 to +45).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Rotation angle in degrees (positive = clockwise)
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::rotation::detect_rotation;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let angle = detect_rotation(&image).unwrap();
|
||||
/// println!("Detected rotation: {:.2}°", angle);
|
||||
/// ```
|
||||
pub fn detect_rotation(image: &GrayImage) -> Result<f32> {
|
||||
let (width, height) = image.dimensions();
|
||||
|
||||
if width < 10 || height < 10 {
|
||||
return Err(PreprocessError::InvalidParameters(
|
||||
"Image too small for rotation detection".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Calculate projection profiles for different angles
|
||||
let angles = [-45.0, -30.0, -15.0, 0.0, 15.0, 30.0, 45.0];
|
||||
let mut max_score = 0.0;
|
||||
let mut best_angle = 0.0;
|
||||
|
||||
for &angle in &angles {
|
||||
let score = calculate_projection_score(image, angle);
|
||||
if score > max_score {
|
||||
max_score = score;
|
||||
best_angle = angle;
|
||||
}
|
||||
}
|
||||
|
||||
// Refine angle with finer search around best candidate
|
||||
let fine_angles: Vec<f32> = (-5..=5).map(|i| best_angle + (i as f32) * 2.0).collect();
|
||||
|
||||
max_score = 0.0;
|
||||
for angle in fine_angles {
|
||||
let score = calculate_projection_score(image, angle);
|
||||
if score > max_score {
|
||||
max_score = score;
|
||||
best_angle = angle;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(best_angle)
|
||||
}
|
||||
|
||||
/// Calculate projection profile score for a given rotation angle
|
||||
///
|
||||
/// Higher scores indicate better alignment with text baselines
|
||||
fn calculate_projection_score(image: &GrayImage, angle: f32) -> f32 {
|
||||
let (width, height) = image.dimensions();
|
||||
|
||||
// For 0 degrees, use direct projection
|
||||
if angle.abs() < 0.1 {
|
||||
return calculate_horizontal_projection_variance(image);
|
||||
}
|
||||
|
||||
// For non-zero angles, calculate projection along rotated axis
|
||||
let rad = angle.to_radians();
|
||||
let cos_a = rad.cos();
|
||||
let sin_a = rad.sin();
|
||||
|
||||
let mut projection = vec![0u32; height as usize];
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = image.get_pixel(x, y)[0];
|
||||
if pixel < 128 {
|
||||
// Project pixel onto rotated horizontal axis
|
||||
let proj_y = ((y as f32) * cos_a - (x as f32) * sin_a) as i32;
|
||||
if proj_y >= 0 && proj_y < height as i32 {
|
||||
projection[proj_y as usize] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate variance of projection (higher = better alignment)
|
||||
calculate_variance(&projection)
|
||||
}
|
||||
|
||||
/// Calculate horizontal projection variance
|
||||
fn calculate_horizontal_projection_variance(image: &GrayImage) -> f32 {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut projection = vec![0u32; height as usize];
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = image.get_pixel(x, y)[0];
|
||||
if pixel < 128 {
|
||||
projection[y as usize] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
calculate_variance(&projection)
|
||||
}
|
||||
|
||||
/// Population variance of a projection profile; 0.0 for an empty slice.
fn calculate_variance(projection: &[u32]) -> f32 {
    if projection.is_empty() {
        return 0.0;
    }

    let n = projection.len() as f32;
    let mean = projection.iter().sum::<u32>() as f32 / n;

    // Mean of squared deviations (population variance, divisor n).
    projection
        .iter()
        .map(|&count| {
            let diff = count as f32 - mean;
            diff * diff
        })
        .sum::<f32>()
        / n
}
|
||||
|
||||
/// Rotate image by specified angle
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `angle` - Rotation angle in degrees (positive = clockwise)
|
||||
///
|
||||
/// # Returns
|
||||
/// Rotated image with bilinear interpolation
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::rotation::rotate_image;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let rotated = rotate_image(&image, 15.0).unwrap();
|
||||
/// ```
|
||||
pub fn rotate_image(image: &GrayImage, angle: f32) -> Result<GrayImage> {
|
||||
if angle.abs() < 0.01 {
|
||||
// No rotation needed
|
||||
return Ok(image.clone());
|
||||
}
|
||||
|
||||
let radians = -angle.to_radians(); // Negate for correct direction
|
||||
let rotated = rotate_about_center(
|
||||
image,
|
||||
radians,
|
||||
Interpolation::Bilinear,
|
||||
Luma([255]), // White background
|
||||
);
|
||||
|
||||
Ok(rotated)
|
||||
}
|
||||
|
||||
/// Detect rotation with confidence score
|
||||
///
|
||||
/// Returns tuple of (angle, confidence) where confidence is 0.0-1.0
|
||||
pub fn detect_rotation_with_confidence(image: &GrayImage) -> Result<(f32, f32)> {
|
||||
let angle = detect_rotation(image)?;
|
||||
|
||||
// Calculate confidence based on projection profile variance difference
|
||||
let current_score = calculate_projection_score(image, angle);
|
||||
let baseline_score = calculate_projection_score(image, 0.0);
|
||||
|
||||
// Confidence is relative improvement over baseline
|
||||
let confidence = if baseline_score > 0.0 {
|
||||
(current_score / baseline_score).min(1.0)
|
||||
} else {
|
||||
0.5 // Default moderate confidence
|
||||
};
|
||||
|
||||
Ok((angle, confidence))
|
||||
}
|
||||
|
||||
/// Auto-rotate image only if confidence is above threshold
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `confidence_threshold` - Minimum confidence (0.0-1.0) to apply rotation
|
||||
///
|
||||
/// # Returns
|
||||
/// Tuple of (rotated_image, angle_applied, confidence)
|
||||
pub fn auto_rotate(image: &GrayImage, confidence_threshold: f32) -> Result<(GrayImage, f32, f32)> {
|
||||
let (angle, confidence) = detect_rotation_with_confidence(image)?;
|
||||
|
||||
if confidence >= confidence_threshold && angle.abs() > 0.5 {
|
||||
let rotated = rotate_image(image, -angle)?;
|
||||
Ok((rotated, angle, confidence))
|
||||
} else {
|
||||
Ok((image.clone(), 0.0, confidence))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a 200x100 white image with four horizontal black lines,
    /// approximating rows of text with horizontal baselines.
    fn create_text_image() -> GrayImage {
        let mut img = GrayImage::new(200, 100);

        // Fill with white
        for pixel in img.pixels_mut() {
            *pixel = Luma([255]);
        }

        // Draw some horizontal lines (simulating text)
        for y in [20, 25, 50, 55] {
            for x in 10..190 {
                img.put_pixel(x, y, Luma([0]));
            }
        }

        img
    }

    /// Horizontal baselines should be detected as (near-)zero rotation.
    #[test]
    fn test_detect_rotation_straight() {
        let img = create_text_image();
        let angle = detect_rotation(&img);

        assert!(angle.is_ok());
        let a = angle.unwrap();
        // Loose bound: the coarse/fine grid search is not exact.
        assert!(a.abs() < 10.0);
    }

    /// Rotation keeps the canvas size (rotate_about_center semantics).
    #[test]
    fn test_rotate_image() {
        let img = create_text_image();

        // Rotate by 15 degrees
        let rotated = rotate_image(&img, 15.0);
        assert!(rotated.is_ok());

        let result = rotated.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    /// Angles below the 0.01-degree cutoff short-circuit to a clone.
    #[test]
    fn test_rotate_no_change() {
        let img = create_text_image();

        // Rotate by ~0 degrees
        let rotated = rotate_image(&img, 0.001);
        assert!(rotated.is_ok());

        let result = rotated.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
    }

    /// Confidence is defined to lie in [0, 1].
    #[test]
    fn test_rotation_confidence() {
        let img = create_text_image();
        let result = detect_rotation_with_confidence(&img);

        assert!(result.is_ok());
        let (angle, confidence) = result.unwrap();

        assert!(confidence >= 0.0 && confidence <= 1.0);
        println!(
            "Detected angle: {:.2}°, confidence: {:.2}",
            angle, confidence
        );
    }

    /// auto_rotate never changes the canvas size, rotated or not.
    #[test]
    fn test_auto_rotate_with_threshold() {
        let img = create_text_image();

        // High threshold - should not rotate if confidence is low
        let result = auto_rotate(&img, 0.95);
        assert!(result.is_ok());

        let (rotated, angle, confidence) = result.unwrap();
        assert_eq!(rotated.dimensions(), img.dimensions());
        println!(
            "Auto-rotate: angle={:.2}°, confidence={:.2}",
            angle, confidence
        );
    }

    /// A non-uniform profile must have strictly positive variance.
    #[test]
    fn test_projection_variance() {
        let projection = vec![10, 50, 100, 50, 10];
        let variance = calculate_variance(&projection);
        assert!(variance > 0.0);
    }

    /// Images under 10x10 are rejected with an error.
    #[test]
    fn test_rotation_small_image_error() {
        let small_img = GrayImage::new(5, 5);
        let result = detect_rotation(&small_img);
        assert!(result.is_err());
    }

    /// Rotate then unrotate; only dimensions are checked because bilinear
    /// resampling does not reproduce pixel values exactly.
    #[test]
    fn test_rotation_roundtrip() {
        let img = create_text_image();

        // Rotate and unrotate
        let rotated = rotate_image(&img, 30.0).unwrap();
        let unrotated = rotate_image(&rotated, -30.0).unwrap();

        assert_eq!(unrotated.dimensions(), img.dimensions());
    }
}
|
||||
483
examples/scipix/src/preprocess/segmentation.rs
Normal file
483
examples/scipix/src/preprocess/segmentation.rs
Normal file
@@ -0,0 +1,483 @@
|
||||
//! Text region detection and segmentation
|
||||
|
||||
use super::{RegionType, Result, TextRegion};
|
||||
use image::GrayImage;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Find text regions in a binary or grayscale image
|
||||
///
|
||||
/// Uses connected component analysis and geometric heuristics to identify
|
||||
/// text regions and classify them by type (text, math, table, etc.)
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale or binary image
|
||||
/// * `min_region_size` - Minimum region area in pixels
|
||||
///
|
||||
/// # Returns
|
||||
/// Vector of detected text regions with bounding boxes
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::segmentation::find_text_regions;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let regions = find_text_regions(&image, 100).unwrap();
|
||||
/// println!("Found {} regions", regions.len());
|
||||
/// ```
|
||||
pub fn find_text_regions(image: &GrayImage, min_region_size: u32) -> Result<Vec<TextRegion>> {
|
||||
// Find connected components
|
||||
let components = connected_components(image);
|
||||
|
||||
// Extract bounding boxes for each component
|
||||
let bboxes = extract_bounding_boxes(&components);
|
||||
|
||||
// Filter by size and merge overlapping regions
|
||||
let filtered = filter_by_size(bboxes, min_region_size);
|
||||
let merged = merge_overlapping_regions(filtered, 10);
|
||||
|
||||
// Find text lines and group components
|
||||
let text_lines = find_text_lines(image, &merged);
|
||||
|
||||
// Classify regions and create TextRegion objects
|
||||
let regions = classify_regions(image, text_lines);
|
||||
|
||||
Ok(regions)
|
||||
}
|
||||
|
||||
/// Connected component labeling using flood-fill algorithm
|
||||
///
|
||||
/// Returns labeled image where each connected component has a unique ID
|
||||
fn connected_components(image: &GrayImage) -> Vec<Vec<u32>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut labels = vec![vec![0u32; width as usize]; height as usize];
|
||||
let mut current_label = 1u32;
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
if labels[y as usize][x as usize] == 0 && image.get_pixel(x, y)[0] < 128 {
|
||||
// Found unlabeled foreground pixel, start flood fill
|
||||
flood_fill(image, &mut labels, x, y, current_label);
|
||||
current_label += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
labels
|
||||
}
|
||||
|
||||
/// Flood fill algorithm for connected component labeling
|
||||
fn flood_fill(image: &GrayImage, labels: &mut [Vec<u32>], start_x: u32, start_y: u32, label: u32) {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut stack = vec![(start_x, start_y)];
|
||||
|
||||
while let Some((x, y)) = stack.pop() {
|
||||
if x >= width || y >= height {
|
||||
continue;
|
||||
}
|
||||
|
||||
if labels[y as usize][x as usize] != 0 || image.get_pixel(x, y)[0] >= 128 {
|
||||
continue;
|
||||
}
|
||||
|
||||
labels[y as usize][x as usize] = label;
|
||||
|
||||
// Add 4-connected neighbors
|
||||
if x > 0 {
|
||||
stack.push((x - 1, y));
|
||||
}
|
||||
if x < width - 1 {
|
||||
stack.push((x + 1, y));
|
||||
}
|
||||
if y > 0 {
|
||||
stack.push((x, y - 1));
|
||||
}
|
||||
if y < height - 1 {
|
||||
stack.push((x, y + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Compute one bounding box per label in a component-label grid.
///
/// Returns a map from label to `(x, y, width, height)`.
fn extract_bounding_boxes(labels: &[Vec<u32>]) -> HashMap<u32, (u32, u32, u32, u32)> {
    // Track (min_x, min_y, max_x, max_y) per label while scanning.
    let mut extents: HashMap<u32, (u32, u32, u32, u32)> = HashMap::new();

    for (y, row) in labels.iter().enumerate() {
        for (x, &label) in row.iter().enumerate() {
            if label == 0 {
                continue; // background
            }

            let (x, y) = (x as u32, y as u32);
            extents
                .entry(label)
                .and_modify(|e| {
                    e.0 = e.0.min(x);
                    e.1 = e.1.min(y);
                    e.2 = e.2.max(x);
                    e.3 = e.3.max(y);
                })
                .or_insert((x, y, x, y));
        }
    }

    // Re-express each inclusive extent as (x, y, width, height).
    extents
        .into_iter()
        .map(|(label, (min_x, min_y, max_x, max_y))| {
            (label, (min_x, min_y, max_x - min_x + 1, max_y - min_y + 1))
        })
        .collect()
}
|
||||
|
||||
/// Keep only boxes whose area (width * height) is at least `min_size`.
///
/// Labels are discarded; the output order follows HashMap iteration and is
/// therefore unspecified.
fn filter_by_size(
    bboxes: HashMap<u32, (u32, u32, u32, u32)>,
    min_size: u32,
) -> Vec<(u32, u32, u32, u32)> {
    let mut kept = Vec::with_capacity(bboxes.len());
    for bbox in bboxes.into_values() {
        let (_, _, w, h) = bbox;
        if w * h >= min_size {
            kept.push(bbox);
        }
    }
    kept
}
|
||||
|
||||
/// Merge overlapping or nearby regions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `regions` - Vector of bounding boxes (x, y, width, height)
|
||||
/// * `merge_distance` - Maximum distance to merge regions
|
||||
pub fn merge_overlapping_regions(
|
||||
regions: Vec<(u32, u32, u32, u32)>,
|
||||
merge_distance: u32,
|
||||
) -> Vec<(u32, u32, u32, u32)> {
|
||||
if regions.is_empty() {
|
||||
return regions;
|
||||
}
|
||||
|
||||
let mut merged = Vec::new();
|
||||
let mut used = HashSet::new();
|
||||
|
||||
for i in 0..regions.len() {
|
||||
if used.contains(&i) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut current = regions[i];
|
||||
let mut changed = true;
|
||||
|
||||
while changed {
|
||||
changed = false;
|
||||
|
||||
for j in (i + 1)..regions.len() {
|
||||
if used.contains(&j) {
|
||||
continue;
|
||||
}
|
||||
|
||||
if boxes_overlap_or_close(¤t, ®ions[j], merge_distance) {
|
||||
current = merge_boxes(¤t, ®ions[j]);
|
||||
used.insert(j);
|
||||
changed = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
merged.push(current);
|
||||
used.insert(i);
|
||||
}
|
||||
|
||||
merged
|
||||
}
|
||||
|
||||
/// True when the two boxes overlap, or come within `distance` pixels of
/// each other on both axes.
fn boxes_overlap_or_close(
    box1: &(u32, u32, u32, u32),
    box2: &(u32, u32, u32, u32),
    distance: u32,
) -> bool {
    let (ax, ay, aw, ah) = *box1;
    let (bx, by, bw, bh) = *box2;

    // Relaxed separating-axis test: each axis passes when the boxes'
    // intervals, padded by `distance`, intersect.
    let near_on_x = ax <= bx + bw + distance && bx <= ax + aw + distance;
    let near_on_y = ay <= by + bh + distance && by <= ay + ah + distance;

    near_on_x && near_on_y
}
|
||||
|
||||
/// Smallest box containing both inputs (rectangle union).
fn merge_boxes(box1: &(u32, u32, u32, u32), box2: &(u32, u32, u32, u32)) -> (u32, u32, u32, u32) {
    let (ax, ay, aw, ah) = *box1;
    let (bx, by, bw, bh) = *box2;

    let left = ax.min(bx);
    let top = ay.min(by);
    let right = (ax + aw).max(bx + bw);
    let bottom = (ay + ah).max(by + bh);

    (left, top, right - left, bottom - top)
}
|
||||
|
||||
/// Find text lines using projection profiles
|
||||
///
|
||||
/// Groups regions into lines based on vertical alignment
|
||||
pub fn find_text_lines(
|
||||
_image: &GrayImage,
|
||||
regions: &[(u32, u32, u32, u32)],
|
||||
) -> Vec<Vec<(u32, u32, u32, u32)>> {
|
||||
if regions.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
// Sort regions by y-coordinate
|
||||
let mut sorted_regions = regions.to_vec();
|
||||
sorted_regions.sort_by_key(|r| r.1);
|
||||
|
||||
let mut lines = Vec::new();
|
||||
let mut current_line = vec![sorted_regions[0]];
|
||||
|
||||
for region in sorted_regions.iter().skip(1) {
|
||||
let (_, y, _, h) = region;
|
||||
let (_, prev_y, _, prev_h) = current_line.last().unwrap();
|
||||
|
||||
// Check if region is on the same line (vertical overlap)
|
||||
let line_height = (*prev_h).max(*h);
|
||||
let distance = if y > prev_y { y - prev_y } else { prev_y - y };
|
||||
|
||||
if distance < line_height / 2 {
|
||||
current_line.push(*region);
|
||||
} else {
|
||||
lines.push(current_line.clone());
|
||||
current_line = vec![*region];
|
||||
}
|
||||
}
|
||||
|
||||
if !current_line.is_empty() {
|
||||
lines.push(current_line);
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
/// Classify regions by type (text, math, table, etc.)
|
||||
fn classify_regions(
|
||||
image: &GrayImage,
|
||||
text_lines: Vec<Vec<(u32, u32, u32, u32)>>,
|
||||
) -> Vec<TextRegion> {
|
||||
let mut regions = Vec::new();
|
||||
|
||||
for line in text_lines {
|
||||
for bbox in line {
|
||||
let (x, y, width, height) = bbox;
|
||||
|
||||
// Calculate features for classification
|
||||
let aspect_ratio = width as f32 / height as f32;
|
||||
let density = calculate_density(image, bbox);
|
||||
|
||||
// Simple heuristic classification
|
||||
let region_type = if aspect_ratio > 10.0 {
|
||||
// Very wide region might be a table or figure caption
|
||||
RegionType::Table
|
||||
} else if aspect_ratio < 0.5 && height > 50 {
|
||||
// Tall region might be a figure
|
||||
RegionType::Figure
|
||||
} else if density > 0.3 && height < 30 {
|
||||
// Dense, small region likely math
|
||||
RegionType::Math
|
||||
} else {
|
||||
// Default to text
|
||||
RegionType::Text
|
||||
};
|
||||
|
||||
regions.push(TextRegion {
|
||||
region_type,
|
||||
bbox: (x, y, width, height),
|
||||
confidence: 0.8, // Default confidence
|
||||
text_height: height as f32,
|
||||
baseline_angle: 0.0,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
regions
|
||||
}
|
||||
|
||||
/// Calculate pixel density in a region
|
||||
fn calculate_density(image: &GrayImage, bbox: (u32, u32, u32, u32)) -> f32 {
|
||||
let (x, y, width, height) = bbox;
|
||||
let total_pixels = (width * height) as f32;
|
||||
|
||||
if total_pixels == 0.0 {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
let mut foreground_pixels = 0;
|
||||
|
||||
for py in y..(y + height) {
|
||||
for px in x..(x + width) {
|
||||
if image.get_pixel(px, py)[0] < 128 {
|
||||
foreground_pixels += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
foreground_pixels as f32 / total_pixels
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use image::Luma;
|
||||
|
||||
fn create_test_image_with_rectangles() -> GrayImage {
|
||||
let mut img = GrayImage::new(200, 200);
|
||||
|
||||
// Fill with white
|
||||
for pixel in img.pixels_mut() {
|
||||
*pixel = Luma([255]);
|
||||
}
|
||||
|
||||
// Draw some black rectangles (simulating text regions)
|
||||
for y in 20..40 {
|
||||
for x in 20..100 {
|
||||
img.put_pixel(x, y, Luma([0]));
|
||||
}
|
||||
}
|
||||
|
||||
for y in 60..80 {
|
||||
for x in 20..120 {
|
||||
img.put_pixel(x, y, Luma([0]));
|
||||
}
|
||||
}
|
||||
|
||||
for y in 100..120 {
|
||||
for x in 20..80 {
|
||||
img.put_pixel(x, y, Luma([0]));
|
||||
}
|
||||
}
|
||||
|
||||
img
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_find_text_regions() {
|
||||
let img = create_test_image_with_rectangles();
|
||||
let regions = find_text_regions(&img, 100);
|
||||
|
||||
assert!(regions.is_ok());
|
||||
let r = regions.unwrap();
|
||||
|
||||
// Should find at least 3 regions
|
||||
assert!(r.len() >= 3);
|
||||
|
||||
for region in r {
|
||||
println!("Region: {:?} at {:?}", region.region_type, region.bbox);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_connected_components() {
|
||||
let img = create_test_image_with_rectangles();
|
||||
let components = connected_components(&img);
|
||||
|
||||
// Check that we have non-zero labels
|
||||
let max_label = components
|
||||
.iter()
|
||||
.flat_map(|row| row.iter())
|
||||
.max()
|
||||
.unwrap_or(&0);
|
||||
|
||||
assert!(*max_label > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_overlapping_regions() {
|
||||
let regions = vec![(10, 10, 50, 20), (40, 10, 50, 20), (100, 100, 30, 30)];
|
||||
|
||||
let merged = merge_overlapping_regions(regions, 10);
|
||||
|
||||
// First two should merge, third stays separate
|
||||
assert_eq!(merged.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_boxes() {
|
||||
let box1 = (10, 10, 50, 20);
|
||||
let box2 = (40, 15, 30, 25);
|
||||
|
||||
let merged = merge_boxes(&box1, &box2);
|
||||
|
||||
assert_eq!(merged.0, 10); // min x
|
||||
assert_eq!(merged.1, 10); // min y
|
||||
assert!(merged.2 >= 50); // width
|
||||
assert!(merged.3 >= 25); // height
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boxes_overlap() {
|
||||
let box1 = (10, 10, 50, 20);
|
||||
let box2 = (40, 10, 50, 20);
|
||||
|
||||
assert!(boxes_overlap_or_close(&box1, &box2, 0));
|
||||
assert!(boxes_overlap_or_close(&box1, &box2, 10));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_boxes_dont_overlap() {
|
||||
let box1 = (10, 10, 20, 20);
|
||||
let box2 = (100, 100, 20, 20);
|
||||
|
||||
assert!(!boxes_overlap_or_close(&box1, &box2, 0));
|
||||
}
|
||||
|
||||
    #[test]
    fn test_find_text_lines() {
        // Four boxes arranged as two horizontal rows of two: y ~10-12 for
        // the first row, y ~50-52 for the second.
        let regions = vec![
            (10, 10, 50, 20),
            (70, 12, 50, 20),
            (10, 50, 50, 20),
            (70, 52, 50, 20),
        ];

        let img = GrayImage::new(200, 100);
        let lines = find_text_lines(&img, &regions);

        // Should find 2 lines, with both row members grouped together.
        assert_eq!(lines.len(), 2);
        assert_eq!(lines[0].len(), 2);
        assert_eq!(lines[1].len(), 2);
    }
|
||||
|
||||
    #[test]
    fn test_calculate_density() {
        let mut img = GrayImage::new(100, 100);

        // Fill the 20x20 region at (10,10) with a checkerboard: exactly
        // half the pixels are black (0), half white (255).
        for y in 10..30 {
            for x in 10..30 {
                let val = if (x + y) % 2 == 0 { 0 } else { 255 };
                img.put_pixel(x, y, Luma([val]));
            }
        }

        // Density over that region should therefore be ~0.5.
        let density = calculate_density(&img, (10, 10, 20, 20));
        assert!((density - 0.5).abs() < 0.1);
    }
|
||||
|
||||
    #[test]
    fn test_filter_by_size() {
        // Bounding boxes keyed by component label, with areas 2500, 100
        // and 900 pixels respectively.
        let mut bboxes = HashMap::new();
        bboxes.insert(1, (10, 10, 50, 50)); // 2500 pixels
        bboxes.insert(2, (100, 100, 10, 10)); // 100 pixels
        bboxes.insert(3, (200, 200, 30, 30)); // 900 pixels

        // Minimum area of 500 drops only the 100-pixel box.
        let filtered = filter_by_size(bboxes, 500);

        // Should keep regions 1 and 3
        assert_eq!(filtered.len(), 2);
    }
|
||||
}
|
||||
400
examples/scipix/src/preprocess/transforms.rs
Normal file
400
examples/scipix/src/preprocess/transforms.rs
Normal file
@@ -0,0 +1,400 @@
|
||||
//! Image transformation functions for preprocessing
|
||||
|
||||
use super::{PreprocessError, Result};
|
||||
use image::{DynamicImage, GrayImage, Luma};
|
||||
use imageproc::filter::gaussian_blur_f32;
|
||||
use std::f32;
|
||||
|
||||
/// Convert image to grayscale
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input color or grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Grayscale image
|
||||
pub fn to_grayscale(image: &DynamicImage) -> GrayImage {
|
||||
image.to_luma8()
|
||||
}
|
||||
|
||||
/// Apply Gaussian blur for noise reduction
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `sigma` - Standard deviation of Gaussian kernel
|
||||
///
|
||||
/// # Returns
|
||||
/// Blurred image
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::transforms::gaussian_blur;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let blurred = gaussian_blur(&image, 1.5).unwrap();
|
||||
/// ```
|
||||
pub fn gaussian_blur(image: &GrayImage, sigma: f32) -> Result<GrayImage> {
|
||||
if sigma <= 0.0 {
|
||||
return Err(PreprocessError::InvalidParameters(
|
||||
"Sigma must be positive".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(gaussian_blur_f32(image, sigma))
|
||||
}
|
||||
|
||||
/// Sharpen image using unsharp mask
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `sigma` - Gaussian blur sigma
|
||||
/// * `amount` - Sharpening strength (typically 0.5-2.0)
|
||||
///
|
||||
/// # Returns
|
||||
/// Sharpened image
|
||||
pub fn sharpen(image: &GrayImage, sigma: f32, amount: f32) -> Result<GrayImage> {
|
||||
if sigma <= 0.0 || amount < 0.0 {
|
||||
return Err(PreprocessError::InvalidParameters(
|
||||
"Invalid sharpening parameters".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let blurred = gaussian_blur_f32(image, sigma);
|
||||
let (width, height) = image.dimensions();
|
||||
let mut result = GrayImage::new(width, height);
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let original = image.get_pixel(x, y)[0] as f32;
|
||||
let blur = blurred.get_pixel(x, y)[0] as f32;
|
||||
|
||||
// Unsharp mask: original + amount * (original - blurred)
|
||||
let sharpened = original + amount * (original - blur);
|
||||
let clamped = sharpened.clamp(0.0, 255.0) as u8;
|
||||
|
||||
result.put_pixel(x, y, Luma([clamped]));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Calculate optimal threshold using Otsu's method
|
||||
///
|
||||
/// Implements full Otsu's algorithm for automatic threshold selection
|
||||
/// based on maximizing inter-class variance.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
///
|
||||
/// # Returns
|
||||
/// Optimal threshold value (0-255)
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::transforms::otsu_threshold;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let threshold = otsu_threshold(&image).unwrap();
|
||||
/// println!("Optimal threshold: {}", threshold);
|
||||
/// ```
|
||||
pub fn otsu_threshold(image: &GrayImage) -> Result<u8> {
|
||||
// Calculate histogram
|
||||
let mut histogram = [0u32; 256];
|
||||
for pixel in image.pixels() {
|
||||
histogram[pixel[0] as usize] += 1;
|
||||
}
|
||||
|
||||
let total_pixels = (image.width() * image.height()) as f64;
|
||||
|
||||
// Calculate cumulative sums
|
||||
let mut sum_total = 0.0;
|
||||
for (i, &count) in histogram.iter().enumerate() {
|
||||
sum_total += (i as f64) * (count as f64);
|
||||
}
|
||||
|
||||
let mut sum_background = 0.0;
|
||||
let mut weight_background = 0.0;
|
||||
let mut max_variance = 0.0;
|
||||
let mut threshold = 0u8;
|
||||
|
||||
// Find threshold that maximizes inter-class variance
|
||||
for (t, &count) in histogram.iter().enumerate() {
|
||||
weight_background += count as f64;
|
||||
if weight_background == 0.0 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let weight_foreground = total_pixels - weight_background;
|
||||
if weight_foreground == 0.0 {
|
||||
break;
|
||||
}
|
||||
|
||||
sum_background += (t as f64) * (count as f64);
|
||||
|
||||
let mean_background = sum_background / weight_background;
|
||||
let mean_foreground = (sum_total - sum_background) / weight_foreground;
|
||||
|
||||
// Inter-class variance
|
||||
let variance =
|
||||
weight_background * weight_foreground * (mean_background - mean_foreground).powi(2);
|
||||
|
||||
if variance > max_variance {
|
||||
max_variance = variance;
|
||||
threshold = t as u8;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(threshold)
|
||||
}
|
||||
|
||||
/// Apply binary thresholding
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `threshold` - Threshold value (0-255)
|
||||
///
|
||||
/// # Returns
|
||||
/// Binary image (0 or 255)
|
||||
pub fn threshold(image: &GrayImage, threshold_val: u8) -> GrayImage {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut result = GrayImage::new(width, height);
|
||||
|
||||
for y in 0..height {
|
||||
for x in 0..width {
|
||||
let pixel = image.get_pixel(x, y)[0];
|
||||
let value = if pixel >= threshold_val { 255 } else { 0 };
|
||||
result.put_pixel(x, y, Luma([value]));
|
||||
}
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
/// Apply adaptive thresholding using local window statistics
|
||||
///
|
||||
/// Uses a sliding window to calculate local mean and applies threshold
|
||||
/// relative to local statistics. Better for images with varying illumination.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `image` - Input grayscale image
|
||||
/// * `window_size` - Size of local window (must be odd)
|
||||
///
|
||||
/// # Returns
|
||||
/// Binary image with adaptive thresholding applied
|
||||
///
|
||||
/// # Example
|
||||
/// ```no_run
|
||||
/// use ruvector_scipix::preprocess::transforms::adaptive_threshold;
|
||||
/// # use image::GrayImage;
|
||||
/// # let image = GrayImage::new(100, 100);
|
||||
/// let binary = adaptive_threshold(&image, 15).unwrap();
|
||||
/// ```
|
||||
pub fn adaptive_threshold(image: &GrayImage, window_size: u32) -> Result<GrayImage> {
|
||||
if window_size % 2 == 0 {
|
||||
return Err(PreprocessError::InvalidParameters(
|
||||
"Window size must be odd".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let (width, height) = image.dimensions();
|
||||
let mut result = GrayImage::new(width, height);
|
||||
let half_window = (window_size / 2) as i32;
|
||||
|
||||
// Use integral image for fast window sum calculation
|
||||
let integral = compute_integral_image(image);
|
||||
|
||||
for y in 0..height as i32 {
|
||||
for x in 0..width as i32 {
|
||||
// Define window bounds
|
||||
let x1 = (x - half_window).max(0);
|
||||
let y1 = (y - half_window).max(0);
|
||||
let x2 = (x + half_window + 1).min(width as i32);
|
||||
let y2 = (y + half_window + 1).min(height as i32);
|
||||
|
||||
// Calculate mean using integral image
|
||||
let area = ((x2 - x1) * (y2 - y1)) as f64;
|
||||
let sum = get_integral_sum(&integral, x1, y1, x2, y2);
|
||||
let mean = (sum as f64 / area) as u8;
|
||||
|
||||
// Apply threshold with small bias
|
||||
let pixel = image.get_pixel(x as u32, y as u32)[0];
|
||||
let bias = 5; // Small bias to reduce noise
|
||||
let value = if pixel >= mean.saturating_sub(bias) {
|
||||
255
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
result.put_pixel(x as u32, y as u32, Luma([value]));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Compute integral image for fast rectangle sum queries
|
||||
fn compute_integral_image(image: &GrayImage) -> Vec<Vec<u64>> {
|
||||
let (width, height) = image.dimensions();
|
||||
let mut integral = vec![vec![0u64; width as usize + 1]; height as usize + 1];
|
||||
|
||||
for y in 1..=height as usize {
|
||||
for x in 1..=width as usize {
|
||||
let pixel = image.get_pixel(x as u32 - 1, y as u32 - 1)[0] as u64;
|
||||
integral[y][x] =
|
||||
pixel + integral[y - 1][x] + integral[y][x - 1] - integral[y - 1][x - 1];
|
||||
}
|
||||
}
|
||||
|
||||
integral
|
||||
}
|
||||
|
||||
/// Sum of the half-open rectangle `[x1, x2) x [y1, y2)` using a
/// summed-area table produced by `compute_integral_image`.
fn get_integral_sum(integral: &[Vec<u64>], x1: i32, y1: i32, x2: i32, y2: i32) -> u64 {
    let (left, top, right, bottom) = (x1 as usize, y1 as usize, x2 as usize, y2 as usize);

    // Additions first: the two positive corners always dominate the two
    // negative ones, so the u64 arithmetic can never underflow.
    integral[bottom][right] + integral[top][left] - integral[top][right] - integral[bottom][left]
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    // Removed `use approx::assert_relative_eq;` — it was never used in
    // this module (all float comparisons use explicit abs-diff checks).

    /// Build a grayscale ramp whose intensity grows with `x + y`.
    fn create_gradient_image(width: u32, height: u32) -> GrayImage {
        let mut img = GrayImage::new(width, height);
        for y in 0..height {
            for x in 0..width {
                let val = ((x + y) * 255 / (width + height)) as u8;
                img.put_pixel(x, y, Luma([val]));
            }
        }
        img
    }

    #[test]
    fn test_to_grayscale() {
        let img = DynamicImage::new_rgb8(100, 100);
        let gray = to_grayscale(&img);
        assert_eq!(gray.dimensions(), (100, 100));
    }

    #[test]
    fn test_gaussian_blur() {
        let img = create_gradient_image(50, 50);
        let blurred = gaussian_blur(&img, 1.0);
        assert!(blurred.is_ok());
        assert_eq!(blurred.unwrap().dimensions(), img.dimensions());
    }

    #[test]
    fn test_gaussian_blur_invalid_sigma() {
        // Non-positive sigma must be rejected.
        let img = create_gradient_image(50, 50);
        assert!(gaussian_blur(&img, -1.0).is_err());
    }

    #[test]
    fn test_sharpen() {
        let img = create_gradient_image(50, 50);
        let sharpened = sharpen(&img, 1.0, 1.5);
        assert!(sharpened.is_ok());
        assert_eq!(sharpened.unwrap().dimensions(), img.dimensions());
    }

    #[test]
    fn test_otsu_threshold() {
        // Bimodal image (ideal for Otsu): left half dark, right half bright.
        let mut img = GrayImage::new(100, 100);
        for y in 0..100 {
            for x in 0..100 {
                let val = if x < 50 { 50 } else { 200 };
                img.put_pixel(x, y, Luma([val]));
            }
        }

        let threshold = otsu_threshold(&img);
        assert!(threshold.is_ok());

        // Otsu picks the variance-maximizing cut; it may land on either
        // mode's boundary, so only require it to lie in [50, 200].
        let t = threshold.unwrap();
        assert!(
            t >= 50 && t <= 200,
            "threshold {} should be between 50 and 200",
            t
        );
    }

    #[test]
    fn test_threshold() {
        let img = create_gradient_image(100, 100);
        let binary = threshold(&img, 128);

        assert_eq!(binary.dimensions(), img.dimensions());
        // Output must be strictly binary.
        assert!(binary.pixels().all(|p| p[0] == 0 || p[0] == 255));
    }

    #[test]
    fn test_adaptive_threshold() {
        let img = create_gradient_image(100, 100);
        let binary = adaptive_threshold(&img, 15);
        assert!(binary.is_ok());

        let result = binary.unwrap();
        assert_eq!(result.dimensions(), img.dimensions());
        // Output must be strictly binary.
        assert!(result.pixels().all(|p| p[0] == 0 || p[0] == 255));
    }

    #[test]
    fn test_adaptive_threshold_invalid_window() {
        // Even window sizes are rejected.
        let img = create_gradient_image(50, 50);
        assert!(adaptive_threshold(&img, 16).is_err());
    }

    #[test]
    fn test_integral_image() {
        // 3x3 image of all 1s: the full-rectangle sum must be 9.
        let mut img = GrayImage::new(3, 3);
        for y in 0..3 {
            for x in 0..3 {
                img.put_pixel(x, y, Luma([1]));
            }
        }

        let integral = compute_integral_image(&img);
        assert_eq!(get_integral_sum(&integral, 0, 0, 3, 3), 9);
    }

    #[test]
    fn test_threshold_extremes() {
        let img = create_gradient_image(100, 100);

        // Threshold 0: every pixel is >= 0, so everything is white.
        assert!(threshold(&img, 0).pixels().all(|p| p[0] == 255));

        // Threshold 255: the gradient never reaches 255, so everything
        // is black.
        assert!(threshold(&img, 255).pixels().all(|p| p[0] == 0));
    }
}
|
||||
189
examples/scipix/src/wasm/api.rs
Normal file
189
examples/scipix/src/wasm/api.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! JavaScript API for Scipix OCR
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{HtmlCanvasElement, ImageData};
|
||||
|
||||
use crate::wasm::canvas::CanvasProcessor;
|
||||
use crate::wasm::memory::WasmBuffer;
|
||||
use crate::wasm::types::{OcrResult, RecognitionFormat};
|
||||
|
||||
static PROCESSOR: OnceCell<Arc<CanvasProcessor>> = OnceCell::new();
|
||||
|
||||
/// Main WASM API for Scipix OCR
///
/// Thin stateful facade over the shared [`CanvasProcessor`]: it carries
/// the requested output format and a confidence cutoff applied to every
/// recognition result before it is returned to JavaScript.
#[wasm_bindgen]
pub struct ScipixWasm {
    /// Shared processor, cloned from the module-level `PROCESSOR`
    /// singleton so every instance reuses the same one.
    processor: Arc<CanvasProcessor>,
    /// Output format requested by the caller (text, LaTeX, or both).
    format: RecognitionFormat,
    /// Results with confidence below this are blanked out
    /// (see `filter_by_confidence`).
    confidence_threshold: f32,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ScipixWasm {
|
||||
/// Create a new ScipixWasm instance
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub async fn new() -> Result<ScipixWasm, JsValue> {
|
||||
let processor = PROCESSOR
|
||||
.get_or_init(|| Arc::new(CanvasProcessor::new()))
|
||||
.clone();
|
||||
|
||||
Ok(ScipixWasm {
|
||||
processor,
|
||||
format: RecognitionFormat::Both,
|
||||
confidence_threshold: 0.5,
|
||||
})
|
||||
}
|
||||
|
||||
/// Recognize text from raw image data
|
||||
#[wasm_bindgen]
|
||||
pub async fn recognize(&self, image_data: &[u8]) -> Result<JsValue, JsValue> {
|
||||
let buffer = WasmBuffer::from_slice(image_data);
|
||||
|
||||
let result = self
|
||||
.processor
|
||||
.process_image_bytes(buffer.as_slice(), self.format)
|
||||
.await
|
||||
.map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
|
||||
|
||||
// Filter by confidence threshold
|
||||
let filtered = self.filter_by_confidence(result);
|
||||
|
||||
serde_wasm_bindgen::to_value(&filtered)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Recognize text from HTML Canvas element
|
||||
#[wasm_bindgen(js_name = recognizeFromCanvas)]
|
||||
pub async fn recognize_from_canvas(
|
||||
&self,
|
||||
canvas: &HtmlCanvasElement,
|
||||
) -> Result<JsValue, JsValue> {
|
||||
let image_data = self
|
||||
.processor
|
||||
.extract_canvas_image(canvas)
|
||||
.map_err(|e| JsValue::from_str(&format!("Canvas extraction failed: {}", e)))?;
|
||||
|
||||
let result = self
|
||||
.processor
|
||||
.process_image_data(&image_data, self.format)
|
||||
.await
|
||||
.map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
|
||||
|
||||
let filtered = self.filter_by_confidence(result);
|
||||
|
||||
serde_wasm_bindgen::to_value(&filtered)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Recognize text from base64-encoded image
|
||||
#[wasm_bindgen(js_name = recognizeBase64)]
|
||||
pub async fn recognize_base64(&self, base64: &str) -> Result<JsValue, JsValue> {
|
||||
// Remove data URL prefix if present
|
||||
let base64_data = if base64.contains(',') {
|
||||
base64.split(',').nth(1).unwrap_or(base64)
|
||||
} else {
|
||||
base64
|
||||
};
|
||||
|
||||
let image_bytes = base64::decode(base64_data)
|
||||
.map_err(|e| JsValue::from_str(&format!("Base64 decode failed: {}", e)))?;
|
||||
|
||||
self.recognize(&image_bytes).await
|
||||
}
|
||||
|
||||
/// Recognize text from ImageData object
|
||||
#[wasm_bindgen(js_name = recognizeImageData)]
|
||||
pub async fn recognize_image_data(&self, image_data: &ImageData) -> Result<JsValue, JsValue> {
|
||||
let result = self
|
||||
.processor
|
||||
.process_image_data(image_data, self.format)
|
||||
.await
|
||||
.map_err(|e| JsValue::from_str(&format!("Recognition failed: {}", e)))?;
|
||||
|
||||
let filtered = self.filter_by_confidence(result);
|
||||
|
||||
serde_wasm_bindgen::to_value(&filtered)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
|
||||
/// Set the output format (text, latex, or both)
|
||||
#[wasm_bindgen(js_name = setFormat)]
|
||||
pub fn set_format(&mut self, format: &str) {
|
||||
self.format = match format.to_lowercase().as_str() {
|
||||
"text" => RecognitionFormat::Text,
|
||||
"latex" => RecognitionFormat::Latex,
|
||||
"both" => RecognitionFormat::Both,
|
||||
_ => RecognitionFormat::Both,
|
||||
};
|
||||
}
|
||||
|
||||
/// Set the confidence threshold (0.0 - 1.0)
|
||||
#[wasm_bindgen(js_name = setConfidenceThreshold)]
|
||||
pub fn set_confidence_threshold(&mut self, threshold: f32) {
|
||||
self.confidence_threshold = threshold.clamp(0.0, 1.0);
|
||||
}
|
||||
|
||||
/// Get the current confidence threshold
|
||||
#[wasm_bindgen(js_name = getConfidenceThreshold)]
|
||||
pub fn get_confidence_threshold(&self) -> f32 {
|
||||
self.confidence_threshold
|
||||
}
|
||||
|
||||
/// Get the version of the library
|
||||
#[wasm_bindgen(js_name = getVersion)]
|
||||
pub fn get_version(&self) -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Get supported output formats
|
||||
#[wasm_bindgen(js_name = getSupportedFormats)]
|
||||
pub fn get_supported_formats(&self) -> Vec<JsValue> {
|
||||
vec![
|
||||
JsValue::from_str("text"),
|
||||
JsValue::from_str("latex"),
|
||||
JsValue::from_str("both"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Batch process multiple images
|
||||
#[wasm_bindgen(js_name = recognizeBatch)]
|
||||
pub async fn recognize_batch(&self, images: Vec<JsValue>) -> Result<JsValue, JsValue> {
|
||||
let mut results = Vec::new();
|
||||
|
||||
for img in images {
|
||||
// Try to process as Uint8Array
|
||||
if let Ok(bytes) = js_sys::Uint8Array::new(&img).to_vec() {
|
||||
match self.recognize(&bytes).await {
|
||||
Ok(result) => results.push(result),
|
||||
Err(e) => {
|
||||
web_sys::console::warn_1(&JsValue::from_str(&format!(
|
||||
"Failed to process image: {:?}",
|
||||
e
|
||||
)));
|
||||
results.push(JsValue::NULL);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(js_sys::Array::from_iter(results).into())
|
||||
}
|
||||
|
||||
// Private helper methods
|
||||
|
||||
fn filter_by_confidence(&self, mut result: OcrResult) -> OcrResult {
|
||||
if result.confidence < self.confidence_threshold {
|
||||
result.text = String::new();
|
||||
result.latex = None;
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new ScipixWasm instance (factory function)
|
||||
#[wasm_bindgen(js_name = createScipix)]
|
||||
pub async fn create_scipix() -> Result<ScipixWasm, JsValue> {
|
||||
ScipixWasm::new().await
|
||||
}
|
||||
217
examples/scipix/src/wasm/canvas.rs
Normal file
217
examples/scipix/src/wasm/canvas.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
//! Canvas and ImageData handling for WASM
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use image::{DynamicImage, ImageBuffer, Rgba};
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{CanvasRenderingContext2d, HtmlCanvasElement, ImageData};
|
||||
|
||||
use crate::wasm::types::{OcrResult, RecognitionFormat};
|
||||
|
||||
/// Processor for canvas and image data
///
/// Currently stateless; the struct exists so model handles or caches can
/// be added later without changing the public API.
pub struct CanvasProcessor {
    // Could add model loading here in the future
}
|
||||
|
||||
impl CanvasProcessor {
|
||||
/// Create a new canvas processor
|
||||
pub fn new() -> Self {
|
||||
Self {}
|
||||
}
|
||||
|
||||
/// Extract image data from HTML canvas element
|
||||
pub fn extract_canvas_image(&self, canvas: &HtmlCanvasElement) -> Result<ImageData> {
|
||||
let context = canvas
|
||||
.get_context("2d")
|
||||
.map_err(|_| anyhow!("Failed to get 2d context"))?
|
||||
.ok_or_else(|| anyhow!("Context is null"))?
|
||||
.dyn_into::<CanvasRenderingContext2d>()
|
||||
.map_err(|_| anyhow!("Failed to cast to 2d context"))?;
|
||||
|
||||
let width = canvas.width();
|
||||
let height = canvas.height();
|
||||
|
||||
context
|
||||
.get_image_data(0.0, 0.0, width as f64, height as f64)
|
||||
.map_err(|_| anyhow!("Failed to get image data"))
|
||||
}
|
||||
|
||||
/// Convert ImageData to DynamicImage
|
||||
pub fn image_data_to_dynamic(&self, image_data: &ImageData) -> Result<DynamicImage> {
|
||||
let width = image_data.width();
|
||||
let height = image_data.height();
|
||||
let data = image_data.data();
|
||||
|
||||
let img_buffer = ImageBuffer::<Rgba<u8>, Vec<u8>>::from_raw(width, height, data.to_vec())
|
||||
.ok_or_else(|| anyhow!("Failed to create image buffer"))?;
|
||||
|
||||
Ok(DynamicImage::ImageRgba8(img_buffer))
|
||||
}
|
||||
|
||||
/// Process raw image bytes
|
||||
pub async fn process_image_bytes(
|
||||
&self,
|
||||
image_bytes: &[u8],
|
||||
format: RecognitionFormat,
|
||||
) -> Result<OcrResult> {
|
||||
// Decode image
|
||||
let img = image::load_from_memory(image_bytes)
|
||||
.map_err(|e| anyhow!("Failed to decode image: {}", e))?;
|
||||
|
||||
self.process_dynamic_image(&img, format).await
|
||||
}
|
||||
|
||||
/// Process ImageData from canvas
|
||||
pub async fn process_image_data(
|
||||
&self,
|
||||
image_data: &ImageData,
|
||||
format: RecognitionFormat,
|
||||
) -> Result<OcrResult> {
|
||||
let img = self.image_data_to_dynamic(image_data)?;
|
||||
self.process_dynamic_image(&img, format).await
|
||||
}
|
||||
|
||||
/// Process a DynamicImage
|
||||
async fn process_dynamic_image(
|
||||
&self,
|
||||
img: &DynamicImage,
|
||||
format: RecognitionFormat,
|
||||
) -> Result<OcrResult> {
|
||||
// Convert to grayscale for processing
|
||||
let gray = img.to_luma8();
|
||||
|
||||
// Apply preprocessing
|
||||
let preprocessed = self.preprocess_image(&gray);
|
||||
|
||||
// Perform OCR (mock implementation for now)
|
||||
// In a real implementation, this would run a model
|
||||
let text = self.extract_text(&preprocessed)?;
|
||||
let latex = if matches!(format, RecognitionFormat::Latex | RecognitionFormat::Both) {
|
||||
Some(self.extract_latex(&preprocessed)?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Calculate confidence (simplified)
|
||||
let confidence = self.calculate_confidence(&text, &latex);
|
||||
|
||||
Ok(OcrResult {
|
||||
text,
|
||||
latex,
|
||||
confidence,
|
||||
metadata: Some(serde_json::json!({
|
||||
"width": img.width(),
|
||||
"height": img.height(),
|
||||
"format": format.to_string(),
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
||||
/// Preprocess image for OCR
|
||||
fn preprocess_image(&self, img: &image::GrayImage) -> image::GrayImage {
|
||||
// Apply simple thresholding
|
||||
let mut output = img.clone();
|
||||
|
||||
for pixel in output.pixels_mut() {
|
||||
let value = pixel.0[0];
|
||||
pixel.0[0] = if value > 128 { 255 } else { 0 };
|
||||
}
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Extract plain text (mock implementation)
|
||||
fn extract_text(&self, img: &image::GrayImage) -> Result<String> {
|
||||
// This would normally run an OCR model
|
||||
// For now, return a placeholder
|
||||
Ok("Recognized text placeholder".to_string())
|
||||
}
|
||||
|
||||
/// Extract LaTeX (mock implementation)
|
||||
fn extract_latex(&self, img: &image::GrayImage) -> Result<String> {
|
||||
// This would normally run a math OCR model
|
||||
// For now, return a placeholder
|
||||
Ok(r"\sum_{i=1}^{n} x_i".to_string())
|
||||
}
|
||||
|
||||
/// Calculate confidence score
|
||||
fn calculate_confidence(&self, text: &str, latex: &Option<String>) -> f32 {
|
||||
// Simple heuristic: longer text = higher confidence
|
||||
let text_score = (text.len() as f32 / 100.0).min(1.0);
|
||||
let latex_score = latex
|
||||
.as_ref()
|
||||
.map(|l| (l.len() as f32 / 50.0).min(1.0))
|
||||
.unwrap_or(0.0);
|
||||
|
||||
(text_score + latex_score) / 2.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CanvasProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert blob URL to image data
///
/// Loads the blob URL into an `HtmlImageElement`, waits for the `load`
/// event, draws the image onto an offscreen canvas, and reads the RGBA
/// pixels back as `ImageData`.
///
/// # Errors
/// Fails if the DOM is unavailable, the image cannot be loaded (bad URL,
/// network error), or canvas pixel readback fails.
#[wasm_bindgen]
pub async fn blob_url_to_image_data(blob_url: &str) -> Result<ImageData, JsValue> {
    use web_sys::{window, HtmlImageElement};

    let window = window().ok_or_else(|| JsValue::from_str("No window"))?;
    let document = window
        .document()
        .ok_or_else(|| JsValue::from_str("No document"))?;

    // Create image element
    let img =
        HtmlImageElement::new().map_err(|_| JsValue::from_str("Failed to create image element"))?;

    img.set_src(blob_url);

    // Wait for image to load. Attaching the handlers after set_src is safe
    // here: load/error events are dispatched from the JS event loop, which
    // cannot run until this synchronous section has finished.
    let promise = js_sys::Promise::new(&mut |resolve, reject| {
        let img_clone = img.clone();
        let onload = Closure::wrap(Box::new(move || {
            resolve.call1(&JsValue::NULL, &img_clone).unwrap();
        }) as Box<dyn FnMut()>);

        img.set_onload(Some(onload.as_ref().unchecked_ref()));
        // Leak the closure so it stays alive until the event fires.
        onload.forget();

        let onerror = Closure::wrap(Box::new(move || {
            reject
                .call1(&JsValue::NULL, &JsValue::from_str("Image load failed"))
                .unwrap();
        }) as Box<dyn FnMut()>);

        img.set_onerror(Some(onerror.as_ref().unchecked_ref()));
        onerror.forget();
    });

    wasm_bindgen_futures::JsFuture::from(promise).await?;

    // Create canvas and draw image
    let canvas = document
        .create_element("canvas")
        .map_err(|_| JsValue::from_str("Failed to create canvas"))?
        .dyn_into::<HtmlCanvasElement>()
        .map_err(|_| JsValue::from_str("Failed to cast to canvas"))?;

    // Size the canvas to the decoded image so nothing is cropped or scaled.
    canvas.set_width(img.natural_width());
    canvas.set_height(img.natural_height());

    let context = canvas
        .get_context("2d")
        .map_err(|_| JsValue::from_str("Failed to get 2d context"))?
        .ok_or_else(|| JsValue::from_str("Context is null"))?
        .dyn_into::<CanvasRenderingContext2d>()
        .map_err(|_| JsValue::from_str("Failed to cast to 2d context"))?;

    context
        .draw_image_with_html_image_element(&img, 0.0, 0.0)
        .map_err(|_| JsValue::from_str("Failed to draw image"))?;

    context
        .get_image_data(0.0, 0.0, canvas.width() as f64, canvas.height() as f64)
        .map_err(|_| JsValue::from_str("Failed to get image data"))
}
|
||||
218
examples/scipix/src/wasm/memory.rs
Normal file
218
examples/scipix/src/wasm/memory.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Memory management for WASM
|
||||
|
||||
use std::ops::Deref;
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Efficient buffer wrapper for WASM memory management
///
/// Owns a plain `Vec<u8>`; the `Drop` impl eagerly clears and shrinks the
/// allocation to help the WASM linear-memory allocator reclaim space.
pub struct WasmBuffer {
    data: Vec<u8>,
}

impl WasmBuffer {
    /// Create a new, empty buffer with the given capacity.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
        }
    }

    /// Create a buffer from a slice (copies the data).
    pub fn from_slice(slice: &[u8]) -> Self {
        Self {
            data: slice.to_vec(),
        }
    }

    /// Create a buffer from a `Vec` (takes ownership, no copy).
    pub fn from_vec(data: Vec<u8>) -> Self {
        Self { data }
    }

    /// Borrow the contents as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        &self.data
    }

    /// Borrow the contents mutably.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        &mut self.data
    }

    /// Number of bytes currently stored.
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// Whether the buffer holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }

    /// Clear the contents (keeps the capacity for reuse).
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Release any excess capacity back to the allocator.
    pub fn shrink_to_fit(&mut self) {
        self.data.shrink_to_fit();
    }

    /// Consume the buffer and return the underlying `Vec`.
    ///
    /// Fix: `WasmBuffer` implements `Drop`, so the field cannot be moved
    /// out of `self` directly (`self.data` here is error E0509).
    /// `mem::take` swaps in an empty `Vec`, which the `Drop` impl then
    /// clears harmlessly.
    pub fn into_vec(mut self) -> Vec<u8> {
        std::mem::take(&mut self.data)
    }
}

impl Deref for WasmBuffer {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        &self.data
    }
}

impl Drop for WasmBuffer {
    fn drop(&mut self) {
        // Explicitly clear and shrink so the allocation is returned to the
        // allocator promptly, helping WASM linear-memory reuse.
        self.data.clear();
        self.data.shrink_to_fit();
    }
}
|
||||
|
||||
/// Shared memory for large images (uses SharedArrayBuffer when available)
|
||||
#[wasm_bindgen]
|
||||
pub struct SharedImageBuffer {
|
||||
buffer: WasmBuffer,
|
||||
width: u32,
|
||||
height: u32,
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl SharedImageBuffer {
|
||||
/// Create a new shared buffer
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(width: u32, height: u32) -> Self {
|
||||
let size = (width * height * 4) as usize; // RGBA
|
||||
Self {
|
||||
buffer: WasmBuffer::with_capacity(size),
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get width
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn width(&self) -> u32 {
|
||||
self.width
|
||||
}
|
||||
|
||||
/// Get height
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn height(&self) -> u32 {
|
||||
self.height
|
||||
}
|
||||
|
||||
/// Get buffer size
|
||||
#[wasm_bindgen(js_name = bufferSize)]
|
||||
pub fn buffer_size(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
/// Get buffer as Uint8Array
|
||||
#[wasm_bindgen(js_name = getBuffer)]
|
||||
pub fn get_buffer(&self) -> js_sys::Uint8Array {
|
||||
js_sys::Uint8Array::from(self.buffer.as_slice())
|
||||
}
|
||||
|
||||
/// Set buffer from Uint8Array
|
||||
#[wasm_bindgen(js_name = setBuffer)]
|
||||
pub fn set_buffer(&mut self, data: &js_sys::Uint8Array) {
|
||||
self.buffer = WasmBuffer::from_vec(data.to_vec());
|
||||
}
|
||||
|
||||
/// Clear the buffer
|
||||
pub fn clear(&mut self) {
|
||||
self.buffer.clear();
|
||||
}
|
||||
}
|
||||
|
||||
/// Memory pool for reusing buffers
|
||||
pub struct MemoryPool {
|
||||
buffers: Vec<WasmBuffer>,
|
||||
max_size: usize,
|
||||
}
|
||||
|
||||
impl MemoryPool {
|
||||
/// Create a new memory pool
|
||||
pub fn new(max_size: usize) -> Self {
|
||||
Self {
|
||||
buffers: Vec::with_capacity(max_size),
|
||||
max_size,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a buffer from the pool or create a new one
|
||||
pub fn acquire(&mut self, size: usize) -> WasmBuffer {
|
||||
self.buffers
|
||||
.pop()
|
||||
.map(|mut buf| {
|
||||
buf.clear();
|
||||
buf
|
||||
})
|
||||
.unwrap_or_else(|| WasmBuffer::with_capacity(size))
|
||||
}
|
||||
|
||||
/// Return a buffer to the pool
|
||||
pub fn release(&mut self, mut buffer: WasmBuffer) {
|
||||
if self.buffers.len() < self.max_size {
|
||||
buffer.clear();
|
||||
self.buffers.push(buffer);
|
||||
}
|
||||
// Otherwise drop the buffer
|
||||
}
|
||||
|
||||
/// Clear all buffers from the pool
|
||||
pub fn clear(&mut self) {
|
||||
self.buffers.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MemoryPool {
|
||||
fn default() -> Self {
|
||||
Self::new(10)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get memory usage statistics
///
/// Returns a JS object `{ available, timestamp }` when a `performance`
/// object is reachable, otherwise `null`. On non-wasm targets this is
/// always `null`.
#[wasm_bindgen(js_name = getMemoryStats)]
pub fn get_memory_stats() -> JsValue {
    #[cfg(target_arch = "wasm32")]
    {
        // Redundant with the module-level prelude import, but harmless.
        use wasm_bindgen::JsValue;

        // Try to get memory info from performance.memory (non-standard)
        // NOTE(review): web_sys::window() is None inside a Web Worker, so
        // this reports null there — confirm that is acceptable for callers.
        let performance = web_sys::window().and_then(|w| w.performance());

        if let Some(perf) = performance {
            serde_wasm_bindgen::to_value(&serde_json::json!({
                "available": true,
                "timestamp": perf.now(),
            }))
            .unwrap_or(JsValue::NULL)
        } else {
            JsValue::NULL
        }
    }

    // Host builds have no browser performance API to query.
    #[cfg(not(target_arch = "wasm32"))]
    JsValue::NULL
}
|
||||
|
||||
/// Force garbage collection (hint to runtime)
///
/// Intentionally a no-op: wasm-bindgen exposes no way to trigger the JS
/// garbage collector. The function exists so JS callers have a stable hook;
/// actual reclamation is entirely up to the host runtime.
#[wasm_bindgen(js_name = forceGC)]
pub fn force_gc() {
    // This is just a hint; actual GC is controlled by the JS runtime
    // In wasm-bindgen, we can't directly trigger GC
    // But we can help by ensuring our memory is freed
}
|
||||
49
examples/scipix/src/wasm/mod.rs
Normal file
49
examples/scipix/src/wasm/mod.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
//! WebAssembly bindings for Scipix OCR
|
||||
//!
|
||||
//! This module provides WASM bindings with wasm-bindgen for browser-based OCR.
|
||||
|
||||
#![cfg(target_arch = "wasm32")]
|
||||
|
||||
pub mod api;
|
||||
pub mod canvas;
|
||||
pub mod memory;
|
||||
pub mod types;
|
||||
pub mod worker;
|
||||
|
||||
pub use api::ScipixWasm;
|
||||
pub use types::*;
|
||||
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// Initialize the WASM module with panic hooks and allocator
///
/// Runs automatically when the module is instantiated (wasm_bindgen start
/// function). Installs the panic hook and logging; both are feature-gated.
#[wasm_bindgen(start)]
pub fn init() {
    // Set panic hook for better error messages
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();

    // Use wee_alloc for smaller binary size
    // NOTE(review): the #[global_allocator] static is declared inside this
    // function body. It still registers at compile time, but the conventional
    // placement is module scope — confirm this location is intentional.
    #[cfg(feature = "wee_alloc")]
    {
        #[global_allocator]
        static ALLOC: wee_alloc::WeeAlloc = wee_alloc::WeeAlloc::INIT;
    }

    // Initialize logging
    tracing_wasm::set_as_global_default();
}
|
||||
|
||||
/// Get the version of the WASM module
|
||||
#[wasm_bindgen]
|
||||
pub fn version() -> String {
|
||||
env!("CARGO_PKG_VERSION").to_string()
|
||||
}
|
||||
|
||||
/// Check if the WASM module is ready
|
||||
#[wasm_bindgen]
|
||||
pub fn is_ready() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// Re-export tracing-wasm for logging
|
||||
use tracing_wasm;
|
||||
179
examples/scipix/src/wasm/types.rs
Normal file
179
examples/scipix/src/wasm/types.rs
Normal file
@@ -0,0 +1,179 @@
|
||||
//! Type definitions for WASM API
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use wasm_bindgen::prelude::*;
|
||||
|
||||
/// OCR result returned to JavaScript
// NOTE(review): public non-Copy fields (String, Option<...>) on a
// #[wasm_bindgen] struct normally require `getter_with_clone`, and explicit
// getters below duplicate the field-generated ones — confirm this compiles
// as intended under the current wasm-bindgen version.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct OcrResult {
    /// Recognized plain text
    pub text: String,

    /// LaTeX representation (if applicable)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub latex: Option<String>,

    /// Confidence score (0.0 - 1.0)
    pub confidence: f32,

    /// Additional metadata
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl OcrResult {
|
||||
/// Create a new OCR result
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new(text: String, confidence: f32) -> Self {
|
||||
Self {
|
||||
text,
|
||||
latex: None,
|
||||
confidence,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the text
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn text(&self) -> String {
|
||||
self.text.clone()
|
||||
}
|
||||
|
||||
/// Get the LaTeX (if available)
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn latex(&self) -> Option<String> {
|
||||
self.latex.clone()
|
||||
}
|
||||
|
||||
/// Get the confidence score
|
||||
#[wasm_bindgen(getter)]
|
||||
pub fn confidence(&self) -> f32 {
|
||||
self.confidence
|
||||
}
|
||||
|
||||
/// Check if result has LaTeX
|
||||
#[wasm_bindgen(js_name = hasLatex)]
|
||||
pub fn has_latex(&self) -> bool {
|
||||
self.latex.is_some()
|
||||
}
|
||||
|
||||
/// Convert to JSON
|
||||
#[wasm_bindgen(js_name = toJSON)]
|
||||
pub fn to_json(&self) -> Result<JsValue, JsValue> {
|
||||
serde_wasm_bindgen::to_value(self)
|
||||
.map_err(|e| JsValue::from_str(&format!("Serialization failed: {}", e)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Recognition output format
// Serialized with serde's default externally-tagged representation
// (no #[serde(tag)] attribute here).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RecognitionFormat {
    /// Plain text only
    Text,
    /// LaTeX only
    Latex,
    /// Both text and LaTeX
    Both,
}
|
||||
|
||||
impl RecognitionFormat {
|
||||
pub fn to_string(&self) -> String {
|
||||
match self {
|
||||
Self::Text => "text".to_string(),
|
||||
Self::Latex => "latex".to_string(),
|
||||
Self::Both => "both".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RecognitionFormat {
|
||||
fn default() -> Self {
|
||||
Self::Both
|
||||
}
|
||||
}
|
||||
|
||||
/// Processing options
// NOTE(review): the public String field on a #[wasm_bindgen] struct normally
// requires `getter_with_clone` (String is not Copy) — confirm this compiles
// under the current wasm-bindgen version.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[wasm_bindgen]
pub struct ProcessingOptions {
    /// Output format
    // Stored as a free-form string; setters below do not validate it.
    pub format: String,

    /// Confidence threshold
    pub confidence_threshold: f32,

    /// Enable preprocessing
    pub preprocess: bool,

    /// Enable postprocessing
    pub postprocess: bool,
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
impl ProcessingOptions {
|
||||
/// Create default options
|
||||
#[wasm_bindgen(constructor)]
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Set format
|
||||
#[wasm_bindgen(js_name = setFormat)]
|
||||
pub fn set_format(&mut self, format: String) {
|
||||
self.format = format;
|
||||
}
|
||||
|
||||
/// Set confidence threshold
|
||||
#[wasm_bindgen(js_name = setConfidenceThreshold)]
|
||||
pub fn set_confidence_threshold(&mut self, threshold: f32) {
|
||||
self.confidence_threshold = threshold;
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProcessingOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
format: "both".to_string(),
|
||||
confidence_threshold: 0.5,
|
||||
preprocess: true,
|
||||
postprocess: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Error types for WASM API
// Rendered to user-facing text via the Display impl below and converted to
// JsValue for crossing the JS boundary.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum WasmError {
    /// Image decoding error
    ImageDecode(String),

    /// Processing error
    Processing(String),

    /// Invalid input
    InvalidInput(String),

    /// Not initialized
    NotInitialized,
}
|
||||
|
||||
impl std::fmt::Display for WasmError {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::ImageDecode(msg) => write!(f, "Image decode error: {}", msg),
|
||||
Self::Processing(msg) => write!(f, "Processing error: {}", msg),
|
||||
Self::InvalidInput(msg) => write!(f, "Invalid input: {}", msg),
|
||||
Self::NotInitialized => write!(f, "WASM module not initialized"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: WasmError's message comes from Display; no source chain.
impl std::error::Error for WasmError {}
|
||||
|
||||
impl From<WasmError> for JsValue {
|
||||
fn from(error: WasmError) -> Self {
|
||||
JsValue::from_str(&error.to_string())
|
||||
}
|
||||
}
|
||||
243
examples/scipix/src/wasm/worker.rs
Normal file
243
examples/scipix/src/wasm/worker.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
//! Web Worker support for off-main-thread OCR processing
|
||||
|
||||
use once_cell::sync::OnceCell;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use web_sys::{DedicatedWorkerGlobalScope, MessageEvent};
|
||||
|
||||
use crate::wasm::api::ScipixWasm;
|
||||
use crate::wasm::types::RecognitionFormat;
|
||||
|
||||
// Worker-wide singleton OCR engine, set exactly once by `init_worker`.
static WORKER_INSTANCE: OnceCell<Arc<ScipixWasm>> = OnceCell::new();
|
||||
|
||||
/// Messages sent from main thread to worker
// Internally tagged: JS posts plain objects with a "type" discriminant,
// e.g. { type: "Process", id, image_data, format }.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerRequest {
    /// Initialize the worker
    Init,

    /// Process an image
    Process {
        // Correlation id echoed back in the worker's responses.
        id: String,
        // Encoded image bytes.
        image_data: Vec<u8>,
        // Requested output format string (passed to set_format as-is).
        format: String,
    },

    /// Process base64 image
    ProcessBase64 {
        id: String,
        base64: String,
        format: String,
    },

    /// Batch process images
    BatchProcess {
        id: String,
        images: Vec<Vec<u8>>,
        format: String,
    },

    /// Terminate worker
    Terminate,
}
|
||||
|
||||
/// Messages sent from worker to main thread
// Internally tagged like WorkerRequest, so the main thread can switch on
// the "type" field of the posted message.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum WorkerResponse {
    /// Worker is ready
    Ready,

    /// Processing started
    Started { id: String },

    /// Processing progress
    Progress {
        id: String,
        // Number of items completed so far (0-based; sent before each item).
        processed: usize,
        total: usize,
    },

    /// Processing completed successfully
    Success {
        id: String,
        result: serde_json::Value,
    },

    /// Processing failed
    Error { id: String, error: String },

    /// Worker terminated
    Terminated,
}
|
||||
|
||||
/// Initialize the worker
|
||||
#[wasm_bindgen(js_name = initWorker)]
|
||||
pub async fn init_worker() -> Result<(), JsValue> {
|
||||
let instance = ScipixWasm::new().await?;
|
||||
WORKER_INSTANCE
|
||||
.set(Arc::new(instance))
|
||||
.map_err(|_| JsValue::from_str("Worker already initialized"))?;
|
||||
|
||||
post_response(WorkerResponse::Ready)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Handle messages from the main thread
|
||||
#[wasm_bindgen(js_name = handleWorkerMessage)]
|
||||
pub async fn handle_worker_message(event: MessageEvent) -> Result<(), JsValue> {
|
||||
let data = event.data();
|
||||
|
||||
let request: WorkerRequest = serde_wasm_bindgen::from_value(data)
|
||||
.map_err(|e| JsValue::from_str(&format!("Invalid message: {}", e)))?;
|
||||
|
||||
match request {
|
||||
WorkerRequest::Init => {
|
||||
init_worker().await?;
|
||||
}
|
||||
|
||||
WorkerRequest::Process {
|
||||
id,
|
||||
image_data,
|
||||
format,
|
||||
} => {
|
||||
process_image(id, image_data, format).await?;
|
||||
}
|
||||
|
||||
WorkerRequest::ProcessBase64 { id, base64, format } => {
|
||||
process_base64(id, base64, format).await?;
|
||||
}
|
||||
|
||||
WorkerRequest::BatchProcess { id, images, format } => {
|
||||
process_batch(id, images, format).await?;
|
||||
}
|
||||
|
||||
WorkerRequest::Terminate => {
|
||||
post_response(WorkerResponse::Terminated)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_image(id: String, image_data: Vec<u8>, format: String) -> Result<(), JsValue> {
|
||||
post_response(WorkerResponse::Started { id: id.clone() })?;
|
||||
|
||||
let instance = WORKER_INSTANCE
|
||||
.get()
|
||||
.ok_or_else(|| JsValue::from_str("Worker not initialized"))?;
|
||||
|
||||
let mut worker_instance = ScipixWasm::new().await?;
|
||||
worker_instance.set_format(&format);
|
||||
|
||||
match worker_instance.recognize(&image_data).await {
|
||||
Ok(result) => {
|
||||
let json_result: serde_json::Value = serde_wasm_bindgen::from_value(result)?;
|
||||
post_response(WorkerResponse::Success {
|
||||
id,
|
||||
result: json_result,
|
||||
})?;
|
||||
}
|
||||
Err(e) => {
|
||||
post_response(WorkerResponse::Error {
|
||||
id,
|
||||
error: format!("{:?}", e),
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_base64(id: String, base64: String, format: String) -> Result<(), JsValue> {
|
||||
post_response(WorkerResponse::Started { id: id.clone() })?;
|
||||
|
||||
let mut worker_instance = ScipixWasm::new().await?;
|
||||
worker_instance.set_format(&format);
|
||||
|
||||
match worker_instance.recognize_base64(&base64).await {
|
||||
Ok(result) => {
|
||||
let json_result: serde_json::Value = serde_wasm_bindgen::from_value(result)?;
|
||||
post_response(WorkerResponse::Success {
|
||||
id,
|
||||
result: json_result,
|
||||
})?;
|
||||
}
|
||||
Err(e) => {
|
||||
post_response(WorkerResponse::Error {
|
||||
id,
|
||||
error: format!("{:?}", e),
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn process_batch(id: String, images: Vec<Vec<u8>>, format: String) -> Result<(), JsValue> {
|
||||
post_response(WorkerResponse::Started { id: id.clone() })?;
|
||||
|
||||
let total = images.len();
|
||||
let mut results = Vec::new();
|
||||
|
||||
let mut worker_instance = ScipixWasm::new().await?;
|
||||
worker_instance.set_format(&format);
|
||||
|
||||
for (idx, image_data) in images.into_iter().enumerate() {
|
||||
// Report progress
|
||||
post_response(WorkerResponse::Progress {
|
||||
id: id.clone(),
|
||||
processed: idx,
|
||||
total,
|
||||
})?;
|
||||
|
||||
match worker_instance.recognize(&image_data).await {
|
||||
Ok(result) => {
|
||||
let json_result: serde_json::Value = serde_wasm_bindgen::from_value(result)?;
|
||||
results.push(json_result);
|
||||
}
|
||||
Err(e) => {
|
||||
web_sys::console::warn_1(&JsValue::from_str(&format!(
|
||||
"Failed to process image {}: {:?}",
|
||||
idx, e
|
||||
)));
|
||||
results.push(serde_json::Value::Null);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
post_response(WorkerResponse::Success {
|
||||
id,
|
||||
result: serde_json::json!({ "results": results }),
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn post_response(response: WorkerResponse) -> Result<(), JsValue> {
|
||||
let global = js_sys::global().dyn_into::<DedicatedWorkerGlobalScope>()?;
|
||||
let message = serde_wasm_bindgen::to_value(&response)?;
|
||||
global.post_message(&message)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Setup worker message listener
///
/// Installs an `onmessage` handler on the dedicated worker's global scope
/// that forwards each incoming event to `handle_worker_message` as a local
/// async task, logging any handler error to the console.
#[wasm_bindgen(js_name = setupWorker)]
pub fn setup_worker() -> Result<(), JsValue> {
    let global = js_sys::global().dyn_into::<DedicatedWorkerGlobalScope>()?;

    // Wrap the Rust handler as a JS-callable closure; each message spawns an
    // async task because handle_worker_message is async.
    let closure = Closure::wrap(Box::new(move |event: MessageEvent| {
        wasm_bindgen_futures::spawn_local(async move {
            if let Err(e) = handle_worker_message(event).await {
                web_sys::console::error_1(&e);
            }
        });
    }) as Box<dyn FnMut(MessageEvent)>);

    global.set_onmessage(Some(closure.as_ref().unchecked_ref()));
    // Deliberate leak: the closure must stay alive for the worker's lifetime,
    // otherwise the JS side would call into freed memory.
    closure.forget(); // Keep closure alive

    Ok(())
}
|
||||
Reference in New Issue
Block a user