Files
wifi-densepose/vendor/ruvector/examples/scipix/examples/api_server.rs

267 lines
7.9 KiB
Rust

//! API server example
//!
//! This example demonstrates how to create a REST API server for OCR processing.
//! It includes graceful shutdown and health checks; model preloading is still a TODO.
//!
//! Usage:
//! ```bash
//! cargo run --example api_server
//!
//! # Then in another terminal:
//! curl -X POST -F "image=@equation.png" http://localhost:8080/ocr
//! ```
use axum::{
extract::{Multipart, State},
http::StatusCode,
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use ruvector_scipix::{OcrConfig, OcrEngine, OutputFormat};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::signal;
use tower_http::cors::CorsLayer;
/// Shared state handed to the request handlers through axum's `State` extractor.
///
/// The engine sits behind an `Arc`, so cloning this struct only clones the
/// pointer rather than the engine itself.
#[derive(Clone)]
struct AppState {
/// OCR engine used by the handlers to run recognition on uploaded images.
engine: Arc<OcrEngine>,
}
/// JSON payload describing the outcome of recognizing one image.
///
/// The handlers in this file fill `text`, `latex` and `confidence` when
/// recognition succeeds (leaving `error` as `None`), and fill only `error`
/// when any step fails.
#[derive(Serialize, Deserialize)]
struct OcrResponse {
/// `true` when recognition produced a result, `false` on any failure.
success: bool,
/// Recognized text, taken from the engine result.
text: Option<String>,
/// LaTeX rendering of the result; `None` if the conversion failed.
latex: Option<String>,
/// Confidence value reported by the engine for the result.
confidence: Option<f32>,
/// Human-readable description of what went wrong.
error: Option<String>,
}
/// JSON payload returned by the `/health` endpoint.
#[derive(Serialize)]
struct HealthResponse {
/// Service status label.
status: String,
/// Version string of the running build.
version: String,
/// Whether the OCR models are reported as loaded.
models_loaded: bool,
}
/// Entry point: builds the OCR engine, wires up the HTTP routes and serves
/// them on port 8080 until a shutdown signal arrives.
///
/// # Errors
///
/// Returns an error if the OCR engine cannot be constructed, if the listening
/// socket cannot be bound, or if the server itself fails while running.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Log at `info` level unless the environment requests something else.
    let log_env = env_logger::Env::default().default_filter_or("info");
    env_logger::Builder::from_env(log_env).init();

    println!("Initializing OCR engine...");
    let ocr_engine = OcrEngine::new(OcrConfig::default()).await?;

    println!("Preloading models...");
    // TODO: Add model preloading method to OcrEngine

    // The engine is shared by every handler through a reference-counted pointer.
    let shared = AppState {
        engine: Arc::new(ocr_engine),
    };

    let routes = Router::new()
        .route("/health", get(health_check))
        .route("/ocr", post(process_ocr))
        .route("/batch", post(process_batch));
    let app = routes.layer(CorsLayer::permissive()).with_state(shared);

    let bind_addr = "0.0.0.0:8080";
    println!("Starting server on http://{}", bind_addr);
    println!("\nEndpoints:");
    println!(" GET /health - Health check");
    println!(" POST /ocr - Process single image");
    println!(" POST /batch - Process multiple images");
    println!("\nPress Ctrl+C to shutdown");

    let tcp = tokio::net::TcpListener::bind(bind_addr).await?;

    // Keep serving until `shutdown_signal` resolves, then shut down gracefully.
    let server = axum::serve(tcp, app).with_graceful_shutdown(shutdown_signal());
    server.await?;

    println!("\nServer shutdown complete");
    Ok(())
}
/// Handles `GET /health` by returning a static health report as JSON.
///
/// The report always says `"healthy"`, carries the crate version baked in at
/// compile time, and hard-codes `models_loaded` to `true` — nothing is probed
/// at request time.
async fn health_check() -> impl IntoResponse {
    let report = HealthResponse {
        status: String::from("healthy"),
        version: String::from(env!("CARGO_PKG_VERSION")),
        models_loaded: true,
    };
    Json(report)
}
async fn process_ocr(State(state): State<AppState>, mut multipart: Multipart) -> impl IntoResponse {
while let Some(field) = multipart.next_field().await.unwrap() {
if field.name() == Some("image") {
let data = match field.bytes().await {
Ok(bytes) => bytes,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(OcrResponse {
success: false,
text: None,
latex: None,
confidence: None,
error: Some(format!("Failed to read image: {}", e)),
}),
);
}
};
let image = match image::load_from_memory(&data) {
Ok(img) => img,
Err(e) => {
return (
StatusCode::BAD_REQUEST,
Json(OcrResponse {
success: false,
text: None,
latex: None,
confidence: None,
error: Some(format!("Invalid image format: {}", e)),
}),
);
}
};
match state.engine.recognize(&image).await {
Ok(result) => {
return (
StatusCode::OK,
Json(OcrResponse {
success: true,
text: Some(result.text.clone()),
latex: result.to_format(OutputFormat::LaTeX).ok(),
confidence: Some(result.confidence),
error: None,
}),
);
}
Err(e) => {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(OcrResponse {
success: false,
text: None,
latex: None,
confidence: None,
error: Some(format!("OCR failed: {}", e)),
}),
);
}
}
}
}
(
StatusCode::BAD_REQUEST,
Json(OcrResponse {
success: false,
text: None,
latex: None,
confidence: None,
error: Some("No image field found".to_string()),
}),
)
}
/// Handles `POST /batch`: runs OCR on every multipart field named `images`.
///
/// Fields with any other name are skipped. Each processed field contributes
/// one `OcrResponse` to the returned JSON array, in upload order; per-image
/// failures (unreadable bytes, undecodable image, engine error) are recorded
/// as `success: false` entries and do not stop the batch.
///
/// Returns `200 OK` with the array when the multipart stream is read to the
/// end. If the stream itself is malformed, processing stops and the handler
/// returns `400 Bad Request` with the results gathered so far followed by one
/// entry describing the stream error.
async fn process_batch(
    State(state): State<AppState>,
    mut multipart: Multipart,
) -> impl IntoResponse {
    // Single place that builds a failed entry so every error path has the
    // same shape.
    let failure = |message: String| OcrResponse {
        success: false,
        text: None,
        latex: None,
        confidence: None,
        error: Some(message),
    };

    let mut results = Vec::new();

    loop {
        // The multipart stream is client-controlled: a malformed or truncated
        // body must end the batch with a 400, not panic the handler.
        let field = match multipart.next_field().await {
            Ok(Some(field)) => field,
            Ok(None) => break,
            Err(e) => {
                results.push(failure(format!("Invalid multipart request: {}", e)));
                return (StatusCode::BAD_REQUEST, Json(results));
            }
        };

        if field.name() != Some("images") {
            continue;
        }

        let data = match field.bytes().await {
            Ok(bytes) => bytes,
            Err(e) => {
                results.push(failure(format!("Failed to read image: {}", e)));
                continue;
            }
        };

        let image = match image::load_from_memory(&data) {
            Ok(img) => img,
            Err(e) => {
                results.push(failure(format!("Invalid image: {}", e)));
                continue;
            }
        };

        let entry = match state.engine.recognize(&image).await {
            Ok(result) => OcrResponse {
                success: true,
                text: Some(result.text.clone()),
                // A failed LaTeX conversion is reported as a missing field
                // rather than failing the entry.
                latex: result.to_format(OutputFormat::LaTeX).ok(),
                confidence: Some(result.confidence),
                error: None,
            },
            Err(e) => failure(format!("OCR failed: {}", e)),
        };
        results.push(entry);
    }

    (StatusCode::OK, Json(results))
}
/// Resolves once the process is asked to stop, so the server can begin a
/// graceful shutdown.
///
/// Waits for Ctrl+C on every platform. On Unix it additionally waits for the
/// terminate signal; elsewhere that branch is a future that never completes.
///
/// # Panics
///
/// Panics if either signal handler cannot be installed.
async fn shutdown_signal() {
    let interrupt = async {
        let outcome = signal::ctrl_c().await;
        outcome.expect("Failed to install Ctrl+C handler");
    };

    #[cfg(unix)]
    let sigterm = async {
        let mut stream = signal::unix::signal(signal::unix::SignalKind::terminate())
            .expect("Failed to install signal handler");
        stream.recv().await;
    };

    // Placeholder with the same output type so `select!` compiles unchanged
    // on platforms without Unix signals.
    #[cfg(not(unix))]
    let sigterm = std::future::pending::<()>();

    // Whichever signal arrives first wins; the other future is dropped.
    tokio::select! {
        _ = interrupt => {
            println!("\nReceived Ctrl+C, shutting down gracefully...");
        },
        _ = sigterm => {
            println!("\nReceived termination signal, shutting down gracefully...");
        },
    }
}