Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'

This commit is contained in:
ruv
2026-02-28 14:39:40 -05:00
7854 changed files with 3522914 additions and 0 deletions

View File

@@ -0,0 +1,576 @@
//! Model download functionality with progress tracking and resume support
use super::progress::{ProgressBar, ProgressStyle};
use super::registry::ModelInfo;
use super::{default_cache_dir, get_hf_token, HubError, Result};
use regex::Regex;
use sha2::{Digest, Sha256};
use std::fs::{self, File};
use std::io::{self, BufWriter, Write};
use std::path::{Path, PathBuf};
// ============================================================================
// Security: URL and Input Validation (H-001)
// ============================================================================
/// Allowed domains for HuggingFace downloads
const ALLOWED_DOMAINS: &[&str] = &["huggingface.co", "hf.co", "cdn-lfs.huggingface.co"];
/// Validate that `url` is an HTTPS URL whose host is on the HuggingFace allowlist.
///
/// Defends against host-spoofing tricks: the authority is cut at the first
/// `/`, `?` or `#` (so `https://evil.com#.huggingface.co/` cannot smuggle an
/// allowed suffix in via the fragment or query), userinfo (`user:pass@host`)
/// is stripped, and the host is compared case-insensitively.
///
/// # Errors
/// Returns `HubError::InvalidFormat` if the scheme is not HTTPS or the host
/// is not an allowed domain (or a subdomain of one).
fn validate_url(url: &str) -> Result<()> {
    // Check for valid HTTPS scheme
    if !url.to_lowercase().starts_with("https://") {
        return Err(HubError::InvalidFormat(
            "Only HTTPS URLs are allowed for downloads".to_string(),
        ));
    }
    let without_scheme = &url[8..]; // Skip "https://"
    // The authority component ends at the first '/', '?' or '#'. Cutting at
    // '/' alone lets "https://evil.com#.huggingface.co/" masquerade as an
    // allowed host, because '#'/'?' terminate the real host for the client.
    let authority_end = without_scheme
        .find(|c| c == '/' || c == '?' || c == '#')
        .unwrap_or(without_scheme.len());
    let authority = &without_scheme[..authority_end];
    // Strip userinfo: the host is whatever follows the last '@'.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    // Remove port if present
    let host_only = host_port.split(':').next().unwrap_or(host_port);
    // DNS host names are case-insensitive
    let host = host_only.to_ascii_lowercase();
    // Check against allowlist (exact match or subdomain of an entry)
    let is_allowed = ALLOWED_DOMAINS
        .iter()
        .any(|&domain| host == domain || host.ends_with(&format!(".{}", domain)));
    if !is_allowed {
        return Err(HubError::InvalidFormat(format!(
            "URL host '{}' is not in the allowed domains: {:?}",
            host, ALLOWED_DOMAINS
        )));
    }
    Ok(())
}
/// Validate repo_id format (prevents CLI injection)
/// Only allows: alphanumeric, /, -, _, .
fn validate_repo_id(repo_id: &str) -> Result<()> {
// Must contain exactly one slash (user/repo format)
let slash_count = repo_id.chars().filter(|&c| c == '/').count();
if slash_count != 1 {
return Err(HubError::InvalidFormat(
"Repository ID must be in format 'username/repo-name'".to_string(),
));
}
// Regex: only allow safe characters
let valid_pattern = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
.expect("Invalid regex pattern");
if !valid_pattern.is_match(repo_id) {
return Err(HubError::InvalidFormat(format!(
"Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
repo_id
)));
}
// Prevent path traversal
if repo_id.contains("..") {
return Err(HubError::InvalidFormat(
"Repository ID cannot contain '..' (path traversal)".to_string(),
));
}
Ok(())
}
/// Canonicalize and validate file path to prevent path traversal
///
/// Ensures `path` resolves to a location inside `base_dir` after symlink
/// resolution. The base directory is created if missing so canonicalization
/// succeeds on a fresh cache, and `..` components are rejected *before* any
/// directories are created so a traversal attempt cannot leave stray
/// directories outside the base.
///
/// # Errors
/// `HubError::Config` on canonicalization failures, `HubError::InvalidFormat`
/// for traversal attempts or malformed paths, `HubError::Io` on directory
/// creation failures.
fn validate_and_canonicalize_path(path: &Path, base_dir: &Path) -> Result<PathBuf> {
    // Reject '..' components up front, before touching the filesystem; the
    // canonical containment check below would catch them too, but only after
    // `create_dir_all` had already materialized the parents on disk.
    if path
        .components()
        .any(|c| matches!(c, std::path::Component::ParentDir))
    {
        return Err(HubError::InvalidFormat(
            "Path cannot contain '..' components (path traversal)".to_string(),
        ));
    }
    // The base directory may not exist yet (first run with a fresh cache);
    // create it so canonicalize() below cannot fail on a missing dir.
    fs::create_dir_all(base_dir)?;
    let canonical_base = base_dir
        .canonicalize()
        .map_err(|e| HubError::Config(format!("Failed to canonicalize base directory: {}", e)))?;
    // Create parent directories if needed, then canonicalize
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    // For new files, canonicalize the parent and append filename
    let canonical_path = if path.exists() {
        path.canonicalize()
            .map_err(|e| HubError::Config(format!("Failed to canonicalize path: {}", e)))?
    } else if let Some(parent) = path.parent() {
        let canonical_parent = parent
            .canonicalize()
            .map_err(|e| HubError::Config(format!("Failed to canonicalize parent path: {}", e)))?;
        canonical_parent.join(
            path.file_name()
                .ok_or_else(|| HubError::InvalidFormat("Invalid file path".to_string()))?,
        )
    } else {
        return Err(HubError::InvalidFormat("Invalid file path".to_string()));
    };
    // Ensure the path is within the base directory (catches symlink escapes)
    if !canonical_path.starts_with(&canonical_base) {
        return Err(HubError::InvalidFormat(format!(
            "Path '{}' is outside allowed directory '{}'",
            canonical_path.display(),
            canonical_base.display()
        )));
    }
    Ok(canonical_path)
}
/// Download configuration
///
/// Controls caching location, authentication, resume behaviour, progress
/// display, checksum verification and retry policy. Obtain sensible
/// defaults via `DownloadConfig::default()`.
#[derive(Debug, Clone)]
pub struct DownloadConfig {
    /// Target directory for downloads
    pub cache_dir: PathBuf,
    /// HuggingFace token for authentication (sent as a Bearer header)
    pub hf_token: Option<String>,
    /// Enable resume for interrupted downloads
    pub resume: bool,
    /// Show progress bar
    pub show_progress: bool,
    /// Verify checksum after download
    pub verify_checksum: bool,
    /// Maximum retry attempts
    // NOTE(review): not consulted anywhere in this file — confirm retries
    // are implemented by a caller before relying on this setting.
    pub max_retries: u32,
}
impl Default for DownloadConfig {
    /// Sensible defaults: resume, progress display and checksum
    /// verification enabled, token read from the environment, cache under
    /// the platform cache directory, and up to 3 retries.
    fn default() -> Self {
        let cache_dir = default_cache_dir();
        let hf_token = get_hf_token();
        Self {
            cache_dir,
            hf_token,
            max_retries: 3,
            resume: true,
            show_progress: true,
            verify_checksum: true,
        }
    }
}
/// Download progress information
///
/// Snapshot of an in-flight transfer; see `percentage()` and the
/// formatting helpers for display.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Total bytes to download
    pub total_bytes: u64,
    /// Bytes downloaded so far
    pub downloaded_bytes: u64,
    /// Download speed in bytes/sec
    pub speed_bps: f64,
    /// Estimated time remaining in seconds
    pub eta_seconds: f64,
    /// Current stage
    pub stage: DownloadStage,
}
/// Download stages
///
/// States of a single download, listed in the order they normally occur.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStage {
    /// Preparing download
    Preparing,
    /// Downloading file
    Downloading,
    /// Verifying checksum
    Verifying,
    /// Complete
    Complete,
    /// Failed, with a human-readable reason
    Failed(String),
}
impl DownloadProgress {
    /// Progress percentage (0.0 when the total size is unknown/zero, to
    /// avoid division by zero).
    pub fn percentage(&self) -> f32 {
        match self.total_bytes {
            0 => 0.0,
            total => (self.downloaded_bytes as f64 / total as f64 * 100.0) as f32,
        }
    }
    /// Human-readable transfer speed, e.g. `"5.00 MB/s"`.
    pub fn speed_str(&self) -> String {
        format_bytes_per_sec(self.speed_bps)
    }
    /// Human-readable ETA, e.g. `"1m 30s"` (fractional seconds truncated).
    pub fn eta_str(&self) -> String {
        format_duration(self.eta_seconds as u64)
    }
}
/// Checksum verifier
///
/// Incremental SHA-256 hasher: feed chunks with `update()` and compare the
/// final hex digest via `verify()`.
pub struct ChecksumVerifier {
    // Running SHA-256 state
    hasher: Sha256,
    // Total number of bytes fed to the hasher so far
    bytes_hashed: u64,
}
impl ChecksumVerifier {
    /// Create a new checksum verifier with an empty hash state.
    pub fn new() -> Self {
        Self {
            hasher: Sha256::new(),
            bytes_hashed: 0,
        }
    }
    /// Feed a chunk of data into the running hash.
    pub fn update(&mut self, data: &[u8]) {
        self.hasher.update(data);
        self.bytes_hashed += data.len() as u64;
    }
    /// Consume the verifier and return the SHA-256 digest as lowercase hex.
    pub fn finalize(self) -> String {
        format!("{:x}", self.hasher.finalize())
    }
    /// Verify against expected checksum (hex, compared case-insensitively)
    ///
    /// Registries sometimes publish digests in uppercase hex; comparing
    /// case-insensitively avoids spurious mismatches for otherwise-equal
    /// digests.
    ///
    /// # Errors
    /// Returns `HubError::ChecksumMismatch` carrying both digests.
    pub fn verify(self, expected: &str) -> Result<()> {
        let actual = self.finalize();
        if actual.eq_ignore_ascii_case(expected) {
            Ok(())
        } else {
            Err(HubError::ChecksumMismatch {
                expected: expected.to_string(),
                actual,
            })
        }
    }
}
impl Default for ChecksumVerifier {
    // Equivalent to `new()`: a fresh, empty hash state.
    fn default() -> Self {
        Self::new()
    }
}
/// Model downloader
///
/// Downloads model files by shelling out to `curl` or `wget`, applying the
/// resume/auth/verification settings from its [`DownloadConfig`].
pub struct ModelDownloader {
    // Download behaviour settings (cache dir, token, resume, ...)
    config: DownloadConfig,
}
impl ModelDownloader {
    /// Create a new downloader with default config
    pub fn new() -> Self {
        Self {
            config: DownloadConfig::default(),
        }
    }
    /// Create a downloader with custom config
    pub fn with_config(config: DownloadConfig) -> Self {
        Self { config }
    }
    /// Download a model by ID from the registry
    ///
    /// # Errors
    /// `HubError::NotFound` when the id is not in the registry, plus any
    /// error from [`ModelDownloader::download`].
    pub fn download_by_id(&self, model_id: &str) -> Result<PathBuf> {
        let registry = super::registry::RuvLtraRegistry::new();
        let model_info = registry
            .get(model_id)
            .ok_or_else(|| HubError::NotFound(model_id.to_string()))?;
        self.download(model_info, None)
    }
    /// Download a model from ModelInfo
    ///
    /// The target defaults to `<cache_dir>/<filename>`. The resolved path is
    /// canonicalized and confined to the cache directory, and the download
    /// URL is checked against the HuggingFace domain allowlist before any
    /// network activity.
    pub fn download(&self, model_info: &ModelInfo, target_path: Option<&Path>) -> Result<PathBuf> {
        // Determine target path
        let path = match target_path {
            Some(p) => p.to_path_buf(),
            None => self.config.cache_dir.join(&model_info.filename),
        };
        // SECURITY: Validate and canonicalize path to prevent path traversal
        let path = validate_and_canonicalize_path(&path, &self.config.cache_dir)?;
        // Already present and resume disabled: trust the existing file
        // (optionally re-verifying its checksum) and skip the network.
        if path.exists() && !self.config.resume {
            if self.config.verify_checksum {
                if let Some(checksum) = &model_info.checksum {
                    self.verify_file(&path, checksum)?;
                }
            }
            return Ok(path);
        }
        // Download the file
        let url = model_info.download_url();
        // SECURITY: Validate URL is from allowed domains
        validate_url(&url)?;
        self.download_file(
            &url,
            &path,
            model_info.size_bytes,
            model_info.checksum.as_deref(),
        )?;
        Ok(path)
    }
    /// Download a file from URL, dispatching to curl or wget.
    fn download_file(
        &self,
        url: &str,
        path: &Path,
        expected_size: u64,
        expected_checksum: Option<&str>,
    ) -> Result<()> {
        // Use curl/wget if available, otherwise fail with helpful message
        if self.has_curl() {
            self.download_with_curl(url, path, expected_size, expected_checksum)
        } else if self.has_wget() {
            self.download_with_wget(url, path, expected_size, expected_checksum)
        } else {
            Err(HubError::Config(
                "Download requires curl or wget. Please install: brew install curl (macOS) or apt install curl (Linux)"
                    .to_string(),
            ))
        }
    }
    /// Check if curl is available
    // NOTE(review): `which` is Unix-specific; on Windows this always reports
    // false — confirm whether Windows support is required.
    fn has_curl(&self) -> bool {
        Self::tool_exists("curl")
    }
    /// Check if wget is available
    fn has_wget(&self) -> bool {
        Self::tool_exists("wget")
    }
    /// True when `which` can resolve `tool` on PATH.
    fn tool_exists(tool: &str) -> bool {
        std::process::Command::new("which")
            .arg(tool)
            .output()
            .map(|o| o.status.success())
            .unwrap_or(false)
    }
    /// Download using curl
    ///
    /// Uses `-C -` to resume partial files when enabled and `--fail` so
    /// HTTP errors surface as a non-zero exit status.
    fn download_with_curl(
        &self,
        url: &str,
        path: &Path,
        _expected_size: u64,
        expected_checksum: Option<&str>,
    ) -> Result<()> {
        let mut cmd = std::process::Command::new("curl");
        cmd.arg("-L") // Follow redirects
            .arg("-#") // Progress bar
            .arg("--fail"); // Fail on HTTP errors
        // Add resume flag if enabled
        if self.config.resume && path.exists() {
            cmd.arg("-C").arg("-"); // Auto-resume
        }
        // Add auth token if provided
        // NOTE(review): the token is visible in the process list while curl
        // runs; consider curl's header-from-file form if that matters.
        if let Some(token) = &self.config.hf_token {
            cmd.arg("-H")
                .arg(format!("Authorization: Bearer {}", token));
        }
        // Pass the path as an OsStr so non-UTF-8 paths don't panic (the
        // previous `path.to_str().unwrap()` would).
        cmd.arg("-o").arg(path).arg(url);
        let status = cmd.status().map_err(|e| HubError::Network(e.to_string()))?;
        if !status.success() {
            return Err(HubError::Network(format!(
                "curl failed with status: {}",
                status
            )));
        }
        // Verify checksum if provided
        if self.config.verify_checksum {
            if let Some(checksum) = expected_checksum {
                self.verify_file(path, checksum)?;
            }
        }
        Ok(())
    }
    /// Download using wget
    ///
    /// Uses `-c` to resume partial files when enabled; progress is still
    /// shown in quiet mode via `--show-progress`.
    fn download_with_wget(
        &self,
        url: &str,
        path: &Path,
        _expected_size: u64,
        expected_checksum: Option<&str>,
    ) -> Result<()> {
        let mut cmd = std::process::Command::new("wget");
        cmd.arg("-q") // Quiet
            .arg("--show-progress"); // But show progress
        // Add resume flag if enabled
        if self.config.resume && path.exists() {
            cmd.arg("-c"); // Continue
        }
        // Add auth token if provided (see the curl note about process lists)
        if let Some(token) = &self.config.hf_token {
            cmd.arg("--header")
                .arg(format!("Authorization: Bearer {}", token));
        }
        // Pass the path as an OsStr so non-UTF-8 paths don't panic
        cmd.arg("-O").arg(path).arg(url);
        let status = cmd.status().map_err(|e| HubError::Network(e.to_string()))?;
        if !status.success() {
            return Err(HubError::Network(format!(
                "wget failed with status: {}",
                status
            )));
        }
        // Verify checksum if provided
        if self.config.verify_checksum {
            if let Some(checksum) = expected_checksum {
                self.verify_file(path, checksum)?;
            }
        }
        Ok(())
    }
    /// Verify file checksum
    ///
    /// Streams the file through SHA-256 in 8 KiB chunks so large models are
    /// never fully loaded into memory.
    ///
    /// # Errors
    /// `HubError::Io` on read failures, `HubError::ChecksumMismatch` when
    /// the digest differs from `expected_checksum`.
    fn verify_file(&self, path: &Path, expected_checksum: &str) -> Result<()> {
        use std::io::Read;
        let mut file = File::open(path)?;
        let mut verifier = ChecksumVerifier::new();
        let mut buffer = [0u8; 8192];
        loop {
            let n = file.read(&mut buffer)?;
            if n == 0 {
                break;
            }
            verifier.update(&buffer[..n]);
        }
        verifier.verify(expected_checksum)
    }
}
impl Default for ModelDownloader {
    // Same as `new()`: a downloader with `DownloadConfig::default()`.
    fn default() -> Self {
        Self::new()
    }
}
/// Download error type
// NOTE(review): unused within this file — `ModelDownloader` reports errors
// through `HubError`; confirm external users before removing.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    /// HTTP error with a description of the failure
    #[error("HTTP error: {0}")]
    Http(String),
    /// IO error (wraps `std::io::Error` via `From`)
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// Checksum mismatch
    #[error("Checksum verification failed")]
    ChecksumMismatch,
}
/// Render a bytes-per-second rate as a human-readable string
/// (`B/s`, `KB/s`, `MB/s` or `GB/s`, using 1024-based units).
fn format_bytes_per_sec(bps: f64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = KB * 1024.0;
    const GB: f64 = MB * 1024.0;
    // Pick the largest unit whose threshold the rate reaches; rates below
    // 1 KiB/s are printed as whole bytes.
    let (value, unit) = if bps >= GB {
        (bps / GB, "GB/s")
    } else if bps >= MB {
        (bps / MB, "MB/s")
    } else if bps >= KB {
        (bps / KB, "KB/s")
    } else {
        return format!("{:.0} B/s", bps);
    };
    format!("{:.2} {}", value, unit)
}
/// Render a whole-second duration as `Ns`, `Nm Ns` or `Nh Nm`.
fn format_duration(secs: u64) -> String {
    match secs {
        s if s < 60 => format!("{}s", s),
        s if s < 3600 => format!("{}m {}s", s / 60, s % 60),
        s => format!("{}h {}m", s / 3600, (s % 3600) / 60),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Default config enables resume, progress display and verification.
    #[test]
    fn test_download_config_default() {
        let config = DownloadConfig::default();
        assert!(config.resume);
        assert!(config.show_progress);
        assert!(config.verify_checksum);
    }
    // 500/1000 bytes is 50%; 1 MiB/s formats with the MB/s suffix.
    #[test]
    fn test_download_progress() {
        let progress = DownloadProgress {
            total_bytes: 1000,
            downloaded_bytes: 500,
            speed_bps: 1024.0 * 1024.0,
            eta_seconds: 30.0,
            stage: DownloadStage::Downloading,
        };
        assert_eq!(progress.percentage(), 50.0);
        assert!(progress.speed_str().contains("MB/s"));
    }
    // SHA-256 digests render as 64 hex characters.
    #[test]
    fn test_checksum_verifier() {
        let mut verifier = ChecksumVerifier::new();
        verifier.update(b"hello world");
        let checksum = verifier.finalize();
        assert!(!checksum.is_empty());
        assert_eq!(checksum.len(), 64); // SHA256 hex is 64 chars
    }
    // Unit thresholds are 1024-based.
    #[test]
    fn test_format_bytes_per_sec() {
        assert_eq!(format_bytes_per_sec(500.0), "500 B/s");
        assert_eq!(format_bytes_per_sec(1024.0 * 10.0), "10.00 KB/s");
        assert_eq!(format_bytes_per_sec(1024.0 * 1024.0 * 5.0), "5.00 MB/s");
    }
    // Seconds roll over into minutes at 60 and hours at 3600.
    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(30), "30s");
        assert_eq!(format_duration(90), "1m 30s");
        assert_eq!(format_duration(3700), "1h 1m");
    }
}

View File

@@ -0,0 +1,145 @@
//! HuggingFace Hub integration for RuvLTRA model management
//!
//! This module provides comprehensive HuggingFace Hub integration for publishing,
//! downloading, and managing RuvLTRA models. It supports:
//!
//! - **Model Upload**: Push GGUF files and SONA weights to HF Hub
//! - **Model Download**: Pull models with automatic quantization selection
//! - **Model Registry**: Pre-configured RuvLTRA model collection
//! - **Progress Tracking**: Visual progress bars with resume support
//! - **Integrity Verification**: Checksum validation for downloads
//!
//! # Example
//!
//! ```rust,ignore
//! use ruvllm::hub::{RuvLtraRegistry, ModelDownloader};
//!
//! // Download a model
//! let registry = RuvLtraRegistry::new();
//! let model_info = registry.get("ruvltra-small")?;
//! let downloader = ModelDownloader::new();
//! let path = downloader.download(model_info, None)?;
//!
//! // Upload a model
//! let uploader = ModelUploader::new("hf_token_here");
//! uploader.upload(
//! "./my-model.gguf",
//! "username/my-ruvltra",
//! Some("My custom RuvLTRA model"),
//! ).await?;
//! ```
pub mod download;
pub mod model_card;
pub mod progress;
pub mod registry;
pub mod upload;
// Re-exports
pub use download::{
ChecksumVerifier, DownloadConfig, DownloadError, DownloadProgress, ModelDownloader,
};
pub use model_card::{
DatasetInfo, Framework, License, MetricResult, ModelCard, ModelCardBuilder, TaskType,
};
pub use progress::{
MultiProgress, ProgressBar, ProgressCallback, ProgressIndicator, ProgressStyle,
};
pub use registry::{
get_model_info, HardwareRequirements, ModelInfo, ModelSize, QuantizationLevel, RuvLtraRegistry,
};
pub use upload::{ModelMetadata, ModelUploader, UploadConfig, UploadError, UploadProgress};
use std::path::PathBuf;
/// Result type for hub operations, using [`HubError`] as the error type.
pub type Result<T> = std::result::Result<T, HubError>;
/// Hub operation errors
///
/// Unified error type for download, upload and registry operations.
#[derive(Debug, thiserror::Error)]
pub enum HubError {
    /// IO error (converted automatically via `From`)
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// HTTP error (only available with the `async-runtime` feature)
    #[cfg(feature = "async-runtime")]
    #[error("HTTP error: {0}")]
    Http(String),
    /// Authentication error
    #[error("Authentication failed: {0}")]
    Auth(String),
    /// Model not found
    #[error("Model not found: {0}")]
    NotFound(String),
    /// Checksum mismatch between the expected and computed digest
    #[error("Checksum verification failed: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },
    /// Invalid model format (also used for rejected URLs, paths and repo ids)
    #[error("Invalid model format: {0}")]
    InvalidFormat(String),
    /// Rate limit exceeded
    #[error("Rate limit exceeded. Retry after {0} seconds")]
    RateLimit(u64),
    /// Network error
    #[error("Network error: {0}")]
    Network(String),
    /// Parse error
    #[error("Parse error: {0}")]
    Parse(String),
    /// Configuration error
    #[error("Configuration error: {0}")]
    Config(String),
}
/// Default HuggingFace Hub API endpoint
pub const HF_ENDPOINT: &str = "https://huggingface.co";
/// Default cache directory for downloaded models
///
/// `<platform cache dir>/huggingface/ruvltra`, falling back to the current
/// directory when no platform cache directory can be determined.
pub fn default_cache_dir() -> PathBuf {
    let base = dirs::cache_dir().unwrap_or_else(|| PathBuf::from("."));
    base.join("huggingface").join("ruvltra")
}
/// Get HuggingFace token from environment
///
/// Checks `HF_TOKEN`, then `HUGGING_FACE_HUB_TOKEN`, then
/// `HUGGINGFACE_API_KEY`; returns the first variable that is set to valid
/// Unicode, or `None` when none are.
pub fn get_hf_token() -> Option<String> {
    ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_API_KEY"]
        .iter()
        .find_map(|name| std::env::var(name).ok())
}
#[cfg(test)]
mod tests {
    use super::*;
    // The cache path always ends in .../huggingface/ruvltra.
    #[test]
    fn test_default_cache_dir() {
        let cache_dir = default_cache_dir();
        assert!(cache_dir.to_string_lossy().contains("huggingface"));
        assert!(cache_dir.to_string_lossy().contains("ruvltra"));
    }
    // thiserror-derived Display strings include the attached payloads.
    #[test]
    fn test_error_display() {
        let err = HubError::NotFound("model-123".to_string());
        assert_eq!(err.to_string(), "Model not found: model-123");
        let err = HubError::ChecksumMismatch {
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(err.to_string().contains("abc123"));
        assert!(err.to_string().contains("def456"));
    }
}

View File

@@ -0,0 +1,426 @@
//! Model card generation for HuggingFace Hub
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Model task type
///
/// Serialized in kebab-case (e.g. `text-generation`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TaskType {
    /// Free-form text generation
    TextGeneration,
    /// Multi-turn chat / assistant use
    ConversationalAi,
    /// Source-code completion
    CodeCompletion,
    /// Answering questions over provided context
    QuestionAnswering,
    /// Condensing text into summaries
    Summarization,
}
/// ML framework
///
/// Serialized in lowercase (e.g. `gguf`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Framework {
    /// GGUF quantized model format
    Gguf,
    /// PyTorch checkpoints
    PyTorch,
    /// TensorFlow models
    TensorFlow,
    /// ONNX exported models
    Onnx,
}
/// Model license
///
/// Serialized in kebab-case; values with no dedicated variant are carried
/// verbatim by `Other`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum License {
    /// MIT License
    Mit,
    /// Apache License 2.0
    Apache20,
    /// GNU GPL v3.0
    Gpl30,
    /// BSD 3-Clause License
    Bsd3Clause,
    /// CreativeML OpenRAIL-M
    CreativemlOpenrailM,
    /// Llama 2 community license
    Llama2,
    /// Any other license identifier, stored verbatim
    Other(String),
}
impl std::str::FromStr for License {
    type Err = ();
    /// Parse a license identifier, case-insensitively.
    ///
    /// Never fails: unrecognized identifiers become `License::Other`
    /// carrying the lowercased input.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        let normalized = s.to_lowercase();
        let license = match normalized.as_str() {
            "mit" => Self::Mit,
            "apache-2.0" | "apache2.0" => Self::Apache20,
            "gpl-3.0" | "gpl3.0" => Self::Gpl30,
            "bsd-3-clause" => Self::Bsd3Clause,
            "creativeml-openrail-m" => Self::CreativemlOpenrailM,
            "llama2" => Self::Llama2,
            _ => Self::Other(normalized),
        };
        Ok(license)
    }
}
/// Dataset information
///
/// One training dataset referenced from a model card.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Dataset name/identifier
    pub name: String,
    /// Dataset description (optional free text)
    pub description: Option<String>,
}
/// Metric result
///
/// One evaluation row in the model card's metrics table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricResult {
    /// Metric name (e.g., "perplexity", "accuracy")
    pub name: String,
    /// Metric value
    pub value: f64,
    /// Dataset used for evaluation (rendered as "N/A" when absent)
    pub dataset: Option<String>,
}
/// Model card for HuggingFace Hub
///
/// Structured description of a model; render to README markdown with
/// `to_markdown()` or construct fluently via `ModelCardBuilder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
    /// Model name
    pub name: String,
    /// Short description
    pub description: Option<String>,
    /// Task type
    pub task: TaskType,
    /// Framework
    pub framework: Framework,
    /// Architecture (e.g., "llama", "qwen2")
    pub architecture: String,
    /// Model license
    pub license: License,
    /// Number of parameters
    pub parameters: u64,
    /// Context window size, in tokens
    pub context_length: usize,
    /// Training datasets
    pub datasets: Vec<DatasetInfo>,
    /// Evaluation metrics
    pub metrics: Vec<MetricResult>,
    /// Model tags (emitted into the YAML frontmatter)
    pub tags: Vec<String>,
    /// Additional metadata (rendered as an unordered bullet list)
    pub metadata: HashMap<String, String>,
}
impl ModelCard {
    /// Convert model card to YAML frontmatter + markdown
    ///
    /// Produces a HuggingFace-style README: a YAML header (language,
    /// license, library, tags) followed by sections for model details,
    /// training data, evaluation metrics, usage snippets and metadata.
    pub fn to_markdown(&self) -> String {
        let mut content = String::new();
        // YAML frontmatter (constant lines pushed directly; `format!` on a
        // string with no placeholders was a pointless allocation)
        content.push_str("---\n");
        content.push_str("language: en\n");
        content.push_str(&format!("license: {}\n", self.license_str()));
        content.push_str("library_name: ruvltra\n");
        if !self.tags.is_empty() {
            content.push_str("tags:\n");
            for tag in &self.tags {
                content.push_str(&format!("- {}\n", tag));
            }
        }
        content.push_str("---\n\n");
        // Model description
        content.push_str(&format!("# {}\n\n", self.name));
        if let Some(desc) = &self.description {
            content.push_str(&format!("{}\n\n", desc));
        }
        // Model details
        content.push_str("## Model Details\n\n");
        content.push_str(&format!("- **Architecture**: {}\n", self.architecture));
        content.push_str(&format!(
            "- **Parameters**: {}\n",
            format_params(self.parameters)
        ));
        content.push_str(&format!(
            "- **Context Length**: {} tokens\n",
            self.context_length
        ));
        content.push_str(&format!("- **Framework**: {:?}\n", self.framework));
        content.push_str(&format!("- **Task**: {:?}\n\n", self.task));
        // Training data
        if !self.datasets.is_empty() {
            content.push_str("## Training Data\n\n");
            for dataset in &self.datasets {
                content.push_str(&format!("- **{}**", dataset.name));
                if let Some(desc) = &dataset.description {
                    content.push_str(&format!(": {}", desc));
                }
                content.push('\n');
            }
            content.push('\n');
        }
        // Evaluation metrics
        if !self.metrics.is_empty() {
            content.push_str("## Evaluation\n\n");
            content.push_str("| Metric | Value | Dataset |\n");
            content.push_str("|--------|-------|----------|\n");
            for metric in &self.metrics {
                content.push_str(&format!(
                    "| {} | {:.2} | {} |\n",
                    metric.name,
                    metric.value,
                    metric.dataset.as_deref().unwrap_or("N/A")
                ));
            }
            content.push('\n');
        }
        // Usage
        // NOTE(review): `self.name.to_lowercase()` is used as a CLI model id;
        // names containing spaces would produce an invalid command — confirm
        // that registry ids never contain whitespace.
        content.push_str("## Usage\n\n");
        content.push_str("```bash\n");
        content.push_str("# Download using ruvllm CLI\n");
        content.push_str(&format!("ruvllm pull {}\n", self.name.to_lowercase()));
        content.push_str("```\n\n");
        content.push_str("```rust\n");
        content.push_str("use ruvllm::hub::ModelDownloader;\n\n");
        content.push_str("let downloader = ModelDownloader::new();\n");
        content.push_str(&format!(
            "let path = downloader.download_by_id(\"{}\")?;\n",
            self.name.to_lowercase()
        ));
        content.push_str("```\n\n");
        // Additional metadata
        // NOTE(review): HashMap iteration order is unstable, so this section
        // is non-deterministic across runs; switch the field to a BTreeMap
        // if stable output is required.
        if !self.metadata.is_empty() {
            content.push_str("## Additional Information\n\n");
            for (key, value) in &self.metadata {
                content.push_str(&format!("- **{}**: {}\n", key, value));
            }
            content.push('\n');
        }
        // Footer
        content.push_str("---\n\n");
        content.push_str("*This model card was generated automatically by RuvLLM*\n");
        content
    }
    /// Map the license enum to its HuggingFace YAML identifier.
    fn license_str(&self) -> &str {
        match &self.license {
            License::Mit => "mit",
            License::Apache20 => "apache-2.0",
            License::Gpl30 => "gpl-3.0",
            License::Bsd3Clause => "bsd-3-clause",
            License::CreativemlOpenrailM => "creativeml-openrail-m",
            License::Llama2 => "llama2",
            License::Other(s) => s,
        }
    }
}
/// Model card builder
///
/// Fluent builder for `ModelCard`; start with `ModelCardBuilder::new`,
/// chain setters, then call `build()`.
pub struct ModelCardBuilder {
    // Required model name; all remaining fields have defaults (see `new`).
    name: String,
    // Optional short description
    description: Option<String>,
    // Task type (default: TextGeneration)
    task: TaskType,
    // Framework (default: Gguf)
    framework: Framework,
    // Architecture string (default: "llama")
    architecture: String,
    // License (default: Mit)
    license: License,
    // Parameter count (default: 0)
    parameters: u64,
    // Context window in tokens (default: 4096)
    context_length: usize,
    // Training datasets, in insertion order
    datasets: Vec<DatasetInfo>,
    // Evaluation metrics, in insertion order
    metrics: Vec<MetricResult>,
    // Frontmatter tags, in insertion order
    tags: Vec<String>,
    // Free-form key/value metadata
    metadata: HashMap<String, String>,
}
impl ModelCardBuilder {
    /// Start a builder for a card named `name`.
    ///
    /// Defaults: text-generation task, GGUF framework, "llama"
    /// architecture, MIT license, 0 parameters, 4096-token context, and
    /// empty dataset/metric/tag/metadata collections.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            task: TaskType::TextGeneration,
            framework: Framework::Gguf,
            architecture: String::from("llama"),
            license: License::Mit,
            parameters: 0,
            context_length: 4096,
            datasets: Vec::new(),
            metrics: Vec::new(),
            tags: Vec::new(),
            metadata: HashMap::new(),
        }
    }
    /// Set the short description.
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }
    /// Set the task type.
    pub fn task(mut self, task: TaskType) -> Self {
        self.task = task;
        self
    }
    /// Set the framework.
    pub fn framework(mut self, framework: Framework) -> Self {
        self.framework = framework;
        self
    }
    /// Set the architecture string (e.g. "llama", "qwen2").
    pub fn architecture(mut self, arch: impl Into<String>) -> Self {
        self.architecture = arch.into();
        self
    }
    /// Set the license.
    pub fn license(mut self, license: License) -> Self {
        self.license = license;
        self
    }
    /// Set the parameter count.
    pub fn parameters(mut self, params: u64) -> Self {
        self.parameters = params;
        self
    }
    /// Set the context window length in tokens.
    pub fn context_length(mut self, length: usize) -> Self {
        self.context_length = length;
        self
    }
    /// Append a training dataset entry.
    pub fn add_dataset(mut self, name: impl Into<String>, desc: Option<String>) -> Self {
        let dataset = DatasetInfo {
            name: name.into(),
            description: desc,
        };
        self.datasets.push(dataset);
        self
    }
    /// Append an evaluation metric row.
    pub fn add_metric(
        mut self,
        name: impl Into<String>,
        value: f64,
        dataset: Option<String>,
    ) -> Self {
        let metric = MetricResult {
            name: name.into(),
            value,
            dataset,
        };
        self.metrics.push(metric);
        self
    }
    /// Append a frontmatter tag.
    pub fn add_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }
    /// Insert (or overwrite) a metadata key/value pair.
    pub fn add_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }
    /// Finalize the builder into a [`ModelCard`].
    pub fn build(self) -> ModelCard {
        let Self {
            name,
            description,
            task,
            framework,
            architecture,
            license,
            parameters,
            context_length,
            datasets,
            metrics,
            tags,
            metadata,
        } = self;
        ModelCard {
            name,
            description,
            task,
            framework,
            architecture,
            license,
            parameters,
            context_length,
            datasets,
            metrics,
            tags,
            metadata,
        }
    }
}
/// Format parameter count as human-readable string.
///
/// Counts from half a billion up are shown in billions with one decimal
/// (e.g. `0.5B`), then millions, thousands, and finally the raw number.
/// The half-billion cutoff matches this module's unit tests, which expect
/// `500_000_000` to render as `"0.5B"` (the previous `>= 1B` threshold
/// rendered it as `"500M"` and failed them).
fn format_params(params: u64) -> String {
    const B: u64 = 1_000_000_000;
    const M: u64 = 1_000_000;
    const K: u64 = 1_000;
    if params >= B / 2 {
        // 0.5B and up: report billions with one decimal place
        format!("{:.1}B", params as f64 / B as f64)
    } else if params >= M {
        format!("{:.0}M", params as f64 / M as f64)
    } else if params >= K {
        format!("{:.0}K", params as f64 / K as f64)
    } else {
        format!("{}", params)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Builder applies each setter and collects tags.
    #[test]
    fn test_model_card_builder() {
        let card = ModelCardBuilder::new("Test Model")
            .description("A test model")
            .architecture("llama")
            .parameters(500_000_000)
            .context_length(4096)
            .add_tag("test")
            .build();
        assert_eq!(card.name, "Test Model");
        assert_eq!(card.parameters, 500_000_000);
        assert_eq!(card.tags.len(), 1);
    }
    // Rendered markdown includes the title, formatted parameter count,
    // dataset names and metric names.
    #[test]
    fn test_model_card_markdown() {
        let card = ModelCardBuilder::new("RuvLTRA Small")
            .description("Compact model")
            .parameters(500_000_000)
            .add_dataset("dataset1", Some("Training data".to_string()))
            .add_metric("perplexity", 5.2, Some("test-set".to_string()))
            .build();
        let markdown = card.to_markdown();
        assert!(markdown.contains("# RuvLTRA Small"));
        assert!(markdown.contains("0.5B"));
        assert!(markdown.contains("dataset1"));
        assert!(markdown.contains("perplexity"));
    }
    // Expected boundaries: K at 1e3, M at 1e6, fractional B at 5e8.
    #[test]
    fn test_format_params() {
        assert_eq!(format_params(500), "500");
        assert_eq!(format_params(5_000), "5K");
        assert_eq!(format_params(5_000_000), "5M");
        assert_eq!(format_params(500_000_000), "0.5B");
        assert_eq!(format_params(3_000_000_000), "3.0B");
    }
    // from_str is infallible; unknown ids land in Other (lowercased).
    #[test]
    fn test_license_from_str() {
        use std::str::FromStr;
        assert_eq!(License::from_str("mit").unwrap(), License::Mit);
        assert_eq!(License::from_str("apache-2.0").unwrap(), License::Apache20);
        match License::from_str("custom-license").unwrap() {
            License::Other(s) => assert_eq!(s, "custom-license"),
            _ => panic!("Expected Other variant"),
        }
    }
}

View File

@@ -0,0 +1,298 @@
//! Progress tracking for download and upload operations
use std::time::{Duration, Instant};
/// Progress bar styles
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressStyle {
    /// Simple bar: [=====> ] 50%
    Bar,
    /// Detailed: [=====> ] 50% (5.2 MB/s, ETA: 30s)
    Detailed,
    /// Minimal: 50% complete
    Minimal,
}
/// Progress indicator for terminal output
///
/// Renders in place by emitting carriage returns; construct with
/// `ProgressBar::new` and customize via the `with_*`/`enabled` builder
/// methods.
pub struct ProgressBar {
    /// Total bytes
    total: u64,
    /// Current bytes
    current: u64,
    /// Start time (basis for speed/ETA calculations)
    start_time: Instant,
    /// Last update time
    // NOTE(review): written on every update but never read in this file —
    // confirm it is needed before removing.
    last_update: Instant,
    /// Progress style
    style: ProgressStyle,
    /// Bar width, in characters
    width: usize,
    /// Show in terminal (when false, updates are tracked but not drawn)
    enabled: bool,
}
impl ProgressBar {
    /// Create a bar tracking `total` bytes (detailed style, 40 columns,
    /// output enabled).
    pub fn new(total: u64) -> Self {
        let now = Instant::now();
        Self {
            total,
            current: 0,
            start_time: now,
            last_update: now,
            style: ProgressStyle::Detailed,
            width: 40,
            enabled: true,
        }
    }
    /// Set progress style.
    pub fn with_style(mut self, style: ProgressStyle) -> Self {
        self.style = style;
        self
    }
    /// Set bar width.
    pub fn with_width(mut self, width: usize) -> Self {
        self.width = width;
        self
    }
    /// Enable or disable terminal output.
    pub fn enabled(mut self, enabled: bool) -> Self {
        self.enabled = enabled;
        self
    }
    /// Set the absolute position and redraw (when output is enabled).
    pub fn update(&mut self, current: u64) {
        self.current = current;
        self.last_update = Instant::now();
        if self.enabled {
            self.render();
        }
    }
    /// Advance the position by `delta`.
    pub fn inc(&mut self, delta: u64) {
        self.update(self.current + delta);
    }
    /// Snap to 100%, redraw, and terminate the progress line.
    pub fn finish(&mut self) {
        self.current = self.total;
        if self.enabled {
            self.render();
            println!(); // move off the \r-overwritten line
        }
    }
    /// Draw the bar to stdout, overwriting the current line via `\r`.
    fn render(&self) {
        use std::io::{self, Write};
        let percentage = match self.total {
            0 => 0.0, // avoid division by zero on unknown totals
            total => (self.current as f64 / total as f64) * 100.0,
        };
        let filled = ((percentage / 100.0) * self.width as f64) as usize;
        let pad = " ".repeat(self.width.saturating_sub(filled));
        match self.style {
            ProgressStyle::Bar => {
                print!("\r[{}>{}] {:.0}%", "=".repeat(filled), pad, percentage);
            }
            ProgressStyle::Detailed => {
                print!(
                    "\r[{}>{}] {:.0}% ({}, ETA: {})",
                    "=".repeat(filled),
                    pad,
                    percentage,
                    format_speed(self.calculate_speed()),
                    format_duration(self.calculate_eta())
                );
            }
            ProgressStyle::Minimal => {
                print!("\r{:.0}% complete", percentage);
            }
        }
        // stdout is line-buffered; flush so the partial line appears now.
        let _ = io::stdout().flush();
    }
    /// Mean throughput in bytes/sec since construction.
    fn calculate_speed(&self) -> f64 {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        if elapsed > 0.0 {
            self.current as f64 / elapsed
        } else {
            0.0
        }
    }
    /// Remaining time at the current mean speed (zero when speed is zero).
    fn calculate_eta(&self) -> Duration {
        let remaining = self.total.saturating_sub(self.current) as f64;
        let speed = self.calculate_speed();
        if speed > 0.0 {
            Duration::from_secs_f64(remaining / speed)
        } else {
            Duration::from_secs(0)
        }
    }
}
/// Render a bytes-per-second rate with the largest fitting 1024-based unit
/// (`GB/s`, `MB/s`, `KB/s`), falling back to whole bytes per second.
fn format_speed(bps: f64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = KB * 1024.0;
    const GB: f64 = MB * 1024.0;
    // Largest-to-smallest unit table; first threshold reached wins.
    for &(threshold, unit) in &[(GB, "GB/s"), (MB, "MB/s"), (KB, "KB/s")] {
        if bps >= threshold {
            return format!("{:.2} {}", bps / threshold, unit);
        }
    }
    format!("{:.0} B/s", bps)
}
/// Render a `Duration` as `Ns`, `Nm Ns` or `Nh Nm` (whole seconds only).
fn format_duration(d: Duration) -> String {
    let total = d.as_secs();
    let (hours, minutes, seconds) = (total / 3600, (total % 3600) / 60, total % 60);
    if hours > 0 {
        format!("{}h {}m", hours, minutes)
    } else if minutes > 0 {
        format!("{}m {}s", minutes, seconds)
    } else {
        format!("{}s", seconds)
    }
}
/// Progress callback function type: invoked with `(current, total)` bytes.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;
/// Progress indicator trait
///
/// Minimal interface for anything that can visualize `(current, total)`
/// progress and be told when the operation completes.
pub trait ProgressIndicator {
    /// Update progress to `current` out of `total` units
    fn update(&mut self, current: u64, total: u64);
    /// Finish progress
    fn finish(&mut self);
}
impl ProgressIndicator for ProgressBar {
    // Delegates to the inherent method; `_total` is ignored because the
    // bar's total is fixed at construction time.
    fn update(&mut self, current: u64, _total: u64) {
        self.update(current);
    }
    fn finish(&mut self) {
        self.finish();
    }
}
/// Multi-progress manager for multiple concurrent operations
///
/// Holds several bars addressed by the index returned from `add()`.
// NOTE(review): bars render with `\r` onto a single terminal line, so
// concurrent rendering overwrites one line rather than stacking.
pub struct MultiProgress {
    // Bars in insertion order; the index doubles as the bar id.
    bars: Vec<ProgressBar>,
}
impl MultiProgress {
    /// Create an empty manager.
    pub fn new() -> Self {
        Self { bars: Vec::new() }
    }
    /// Register a bar and return its id (its index in insertion order).
    pub fn add(&mut self, bar: ProgressBar) -> usize {
        self.bars.push(bar);
        self.bars.len() - 1
    }
    /// Update the bar with the given id; unknown ids are ignored.
    pub fn update(&mut self, id: usize, current: u64) {
        if let Some(bar) = self.bars.get_mut(id) {
            bar.update(current);
        }
    }
    /// Finish every registered bar.
    pub fn finish_all(&mut self) {
        self.bars.iter_mut().for_each(|bar| bar.finish());
    }
}
impl Default for MultiProgress {
    // Same as `new()`: an empty manager with no bars.
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // A new bar starts at zero progress.
    #[test]
    fn test_progress_bar_creation() {
        let pb = ProgressBar::new(1000);
        assert_eq!(pb.total, 1000);
        assert_eq!(pb.current, 0);
    }
    // update() sets the absolute position (output disabled for tests).
    #[test]
    fn test_progress_update() {
        let mut pb = ProgressBar::new(1000).enabled(false);
        pb.update(500);
        assert_eq!(pb.current, 500);
    }
    // inc() accumulates deltas.
    #[test]
    fn test_progress_increment() {
        let mut pb = ProgressBar::new(1000).enabled(false);
        pb.inc(100);
        pb.inc(100);
        assert_eq!(pb.current, 200);
    }
    // Unit thresholds are 1024-based.
    #[test]
    fn test_format_speed() {
        assert_eq!(format_speed(500.0), "500 B/s");
        assert_eq!(format_speed(1024.0 * 10.0), "10.00 KB/s");
        assert_eq!(format_speed(1024.0 * 1024.0 * 5.0), "5.00 MB/s");
    }
    // Seconds roll over into minutes at 60 and hours at 3600.
    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(Duration::from_secs(30)), "30s");
        assert_eq!(format_duration(Duration::from_secs(90)), "1m 30s");
        assert_eq!(format_duration(Duration::from_secs(3700)), "1h 1m");
    }
    // Each added bar gets a distinct id and updates independently.
    #[test]
    fn test_multi_progress() {
        let mut mp = MultiProgress::new();
        let id1 = mp.add(ProgressBar::new(100).enabled(false));
        let id2 = mp.add(ProgressBar::new(200).enabled(false));
        mp.update(id1, 50);
        mp.update(id2, 100);
        assert_eq!(mp.bars[id1].current, 50);
        assert_eq!(mp.bars[id2].current, 100);
    }
}

View File

@@ -0,0 +1,443 @@
//! RuvLTRA model registry with pre-configured models
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Model size category
///
/// Ranges reflect how the registry below assigns sizes (e.g. the 0.05B
/// LoRA adapter is `Tiny`, the 0.5B models are `Small`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelSize {
    /// Tiny models (< 0.5B parameters, e.g. LoRA adapters)
    Tiny,
    /// Small models (0.5B - 1B parameters)
    Small,
    /// Medium models (1B - 5B parameters)
    Medium,
    /// Large models (5B - 10B parameters)
    Large,
}
/// Quantization level
///
/// Ordered from most to least aggressive compression.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum QuantizationLevel {
    /// 4-bit quantization (smallest, ~662MB for 0.5B model)
    Q4,
    /// 5-bit quantization (balanced size/quality trade-off)
    Q5,
    /// 8-bit quantization (highest quality among quantized levels)
    Q8,
    /// FP16 half precision (no quantization)
    FP16,
}
impl QuantizationLevel {
/// Get file size multiplier relative to FP16
pub fn size_multiplier(&self) -> f32 {
match self {
Self::Q4 => 0.25,
Self::Q5 => 0.3125,
Self::Q8 => 0.5,
Self::FP16 => 1.0,
}
}
/// Get expected memory reduction
pub fn memory_reduction(&self) -> f32 {
match self {
Self::Q4 => 0.75, // 75% reduction
Self::Q5 => 0.69, // 69% reduction
Self::Q8 => 0.50, // 50% reduction
Self::FP16 => 0.0, // No reduction
}
}
}
/// Hardware requirements for model execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HardwareRequirements {
    /// Minimum RAM in GB needed to load and run the model
    pub min_ram_gb: f32,
    /// Recommended RAM in GB for comfortable operation
    pub recommended_ram_gb: f32,
    /// Supports Apple Neural Engine
    pub supports_ane: bool,
    /// Supports Metal GPU acceleration
    pub supports_metal: bool,
    /// Supports CUDA
    pub supports_cuda: bool,
    /// Minimum GPU VRAM in GB (None when GPU execution is not applicable)
    pub min_vram_gb: Option<f32>,
}
/// Model information in the registry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier (e.g., "ruvltra-small")
    pub id: String,
    /// Display name
    pub name: String,
    /// HuggingFace repository (e.g., "ruv/ruvltra")
    pub repo: String,
    /// Model filename on HF Hub
    pub filename: String,
    /// Model size category
    pub size: ModelSize,
    /// Quantization level
    pub quantization: QuantizationLevel,
    /// File size in bytes
    pub size_bytes: u64,
    /// SHA256 checksum (None until the artifact is published)
    pub checksum: Option<String>,
    /// Number of parameters (in billions)
    pub params_b: f32,
    /// Context window size in tokens
    pub context_length: usize,
    /// Hardware requirements
    pub hardware: HardwareRequirements,
    /// Human-readable model description
    pub description: String,
    /// Whether this is a LoRA adapter rather than a full model
    pub is_adapter: bool,
    /// Registry ID of the base model required (adapters only)
    pub base_model: Option<String>,
    /// Includes SONA pre-trained weights
    pub has_sona_weights: bool,
}
impl ModelInfo {
    /// Direct download URL for this model file (HF "resolve/main" endpoint).
    pub fn download_url(&self) -> String {
        format!(
            "https://huggingface.co/{}/resolve/main/{}",
            self.repo, self.filename
        )
    }
    /// URL of the model's page on the HuggingFace Hub.
    pub fn hub_url(&self) -> String {
        format!("https://huggingface.co/{}", self.repo)
    }
    /// Estimate download time in seconds at `speed_mbps` (megabytes/sec).
    ///
    /// NOTE(review): a zero speed yields `inf` (f32 division) — callers are
    /// expected to pass a positive speed.
    pub fn estimate_download_time(&self, speed_mbps: f32) -> f32 {
        let size_mb = self.size_bytes as f32 / (1024.0 * 1024.0);
        size_mb / speed_mbps
    }
    /// True when `available_gb` of RAM meets this model's stated minimum.
    pub fn fits_in_ram(&self, available_gb: f32) -> bool {
        self.hardware.min_ram_gb <= available_gb
    }
}
/// RuvLTRA model registry
///
/// In-memory catalog mapping model ID (e.g. "ruvltra-small") to its
/// [`ModelInfo`]; populated with the known models by [`RuvLtraRegistry::new`].
pub struct RuvLtraRegistry {
    // Keyed by ModelInfo::id.
    models: HashMap<String, ModelInfo>,
}
impl RuvLtraRegistry {
    /// Create a new registry pre-populated with the known RuvLTRA models
    /// (two Small variants, two Medium variants, one LoRA adapter).
    pub fn new() -> Self {
        let mut models = HashMap::new();
        // RuvLTRA-Small (0.5B) - Q4 quantization
        models.insert(
            "ruvltra-small".to_string(),
            ModelInfo {
                id: "ruvltra-small".to_string(),
                name: "RuvLTRA Small (0.5B Q4)".to_string(),
                repo: "ruv/ruvltra".to_string(),
                filename: "ruvltra-small-0.5b-q4_k_m.gguf".to_string(),
                size: ModelSize::Small,
                quantization: QuantizationLevel::Q4,
                size_bytes: 662_000_000, // ~662MB
                checksum: None,          // Set after publishing
                params_b: 0.5,
                context_length: 4096,
                hardware: HardwareRequirements {
                    min_ram_gb: 1.0,
                    recommended_ram_gb: 2.0,
                    supports_ane: true,
                    supports_metal: true,
                    supports_cuda: true,
                    min_vram_gb: Some(1.0),
                },
                description: "Compact RuvLTRA model optimized for edge devices. \
                    Includes SONA pre-trained weights for adaptive learning."
                    .to_string(),
                is_adapter: false,
                base_model: None,
                has_sona_weights: true,
            },
        );
        // RuvLTRA-Small (0.5B) - Q8 quantization
        models.insert(
            "ruvltra-small-q8".to_string(),
            ModelInfo {
                id: "ruvltra-small-q8".to_string(),
                name: "RuvLTRA Small (0.5B Q8)".to_string(),
                repo: "ruv/ruvltra".to_string(),
                filename: "ruvltra-small-0.5b-q8_0.gguf".to_string(),
                size: ModelSize::Small,
                quantization: QuantizationLevel::Q8,
                size_bytes: 1_324_000_000, // ~1.3GB
                checksum: None,
                params_b: 0.5,
                context_length: 4096,
                hardware: HardwareRequirements {
                    min_ram_gb: 2.0,
                    recommended_ram_gb: 4.0,
                    supports_ane: true,
                    supports_metal: true,
                    supports_cuda: true,
                    min_vram_gb: Some(2.0),
                },
                description: "High-quality Q8 quantization for better accuracy.".to_string(),
                is_adapter: false,
                base_model: None,
                has_sona_weights: true,
            },
        );
        // RuvLTRA-Medium - Q4 quantization
        // NOTE(review): the display name and params_b say 3B, but the
        // filename says "1.1b" — confirm which is correct against the
        // published artifacts before relying on either.
        models.insert(
            "ruvltra-medium".to_string(),
            ModelInfo {
                id: "ruvltra-medium".to_string(),
                name: "RuvLTRA Medium (3B Q4)".to_string(),
                repo: "ruv/ruvltra".to_string(),
                filename: "ruvltra-medium-1.1b-q4_k_m.gguf".to_string(),
                size: ModelSize::Medium,
                quantization: QuantizationLevel::Q4,
                size_bytes: 2_100_000_000, // ~2.1GB
                checksum: None,
                params_b: 3.0,
                context_length: 8192,
                hardware: HardwareRequirements {
                    min_ram_gb: 4.0,
                    recommended_ram_gb: 8.0,
                    supports_ane: true,
                    supports_metal: true,
                    supports_cuda: true,
                    min_vram_gb: Some(4.0),
                },
                description: "Balanced RuvLTRA model for general-purpose tasks. \
                    Extended context window with SONA learning."
                    .to_string(),
                is_adapter: false,
                base_model: None,
                has_sona_weights: true,
            },
        );
        // RuvLTRA-Medium - Q8 quantization (same 3B-vs-1.1b naming caveat)
        models.insert(
            "ruvltra-medium-q8".to_string(),
            ModelInfo {
                id: "ruvltra-medium-q8".to_string(),
                name: "RuvLTRA Medium (3B Q8)".to_string(),
                repo: "ruv/ruvltra".to_string(),
                filename: "ruvltra-medium-1.1b-q8_0.gguf".to_string(),
                size: ModelSize::Medium,
                quantization: QuantizationLevel::Q8,
                size_bytes: 4_200_000_000, // ~4.2GB
                checksum: None,
                params_b: 3.0,
                context_length: 8192,
                hardware: HardwareRequirements {
                    min_ram_gb: 6.0,
                    recommended_ram_gb: 12.0,
                    supports_ane: true,
                    supports_metal: true,
                    supports_cuda: true,
                    min_vram_gb: Some(6.0),
                },
                description: "High-quality Medium model with Q8 quantization.".to_string(),
                is_adapter: false,
                base_model: None,
                has_sona_weights: true,
            },
        );
        // RuvLTRA-Small-Coder (LoRA adapter)
        models.insert(
            "ruvltra-small-coder".to_string(),
            ModelInfo {
                id: "ruvltra-small-coder".to_string(),
                name: "RuvLTRA Small Coder (LoRA)".to_string(),
                repo: "ruv/ruvltra".to_string(),
                filename: "ruvltra-small-coder-lora.safetensors".to_string(),
                size: ModelSize::Tiny,
                quantization: QuantizationLevel::FP16,
                size_bytes: 50_000_000, // ~50MB (LoRA is small)
                checksum: None,
                params_b: 0.05, // Adapter parameters
                context_length: 4096,
                hardware: HardwareRequirements {
                    min_ram_gb: 0.1,
                    recommended_ram_gb: 0.5,
                    supports_ane: true,
                    supports_metal: true,
                    supports_cuda: true,
                    min_vram_gb: None,
                },
                description: "LoRA adapter for code completion. \
                    Requires ruvltra-small or ruvltra-small-q8 base model."
                    .to_string(),
                is_adapter: true,
                base_model: Some("ruvltra-small".to_string()),
                has_sona_weights: false,
            },
        );
        Self { models }
    }
    /// Get model info by ID, or `None` when the ID is unknown.
    pub fn get(&self, id: &str) -> Option<&ModelInfo> {
        self.models.get(id)
    }
    /// Get all available models (unspecified order).
    pub fn list_all(&self) -> Vec<&ModelInfo> {
        self.models.values().collect()
    }
    /// Get all models in the given size category.
    pub fn list_by_size(&self, size: ModelSize) -> Vec<&ModelInfo> {
        self.models.values().filter(|m| m.size == size).collect()
    }
    /// Get base models only (excludes LoRA adapters).
    pub fn list_base_models(&self) -> Vec<&ModelInfo> {
        self.models.values().filter(|m| !m.is_adapter).collect()
    }
    /// Get adapters registered for the given base model ID.
    pub fn list_adapters(&self, base_model: &str) -> Vec<&ModelInfo> {
        self.models
            .values()
            .filter(|m| {
                m.is_adapter
                    && m.base_model
                        .as_ref()
                        .map(|b| b == base_model)
                        .unwrap_or(false)
            })
            .collect()
    }
    /// Recommend the largest base model (by parameter count) that fits in
    /// `available_gb` of RAM; `None` when nothing fits.
    pub fn recommend_for_ram(&self, available_gb: f32) -> Option<&ModelInfo> {
        // `max_by` scans once instead of sorting the whole candidate list,
        // and `unwrap_or(Equal)` removes the panic path the previous
        // `partial_cmp(..).unwrap()` had if params_b were ever NaN.
        self.models
            .values()
            .filter(|m| !m.is_adapter && m.fits_in_ram(available_gb))
            .max_by(|a, b| {
                a.params_b
                    .partial_cmp(&b.params_b)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }
    /// Get all registered model IDs (unspecified order).
    pub fn model_ids(&self) -> Vec<String> {
        self.models.keys().cloned().collect()
    }
}
impl Default for RuvLtraRegistry {
    // `Default` simply builds the fully pre-configured registry.
    fn default() -> Self {
        Self::new()
    }
}
/// Look up a model by ID without holding a registry (convenience function).
///
/// Builds a fresh registry on every call and clones the entry; prefer
/// keeping a [`RuvLtraRegistry`] around for repeated lookups.
pub fn get_model_info(id: &str) -> Option<ModelInfo> {
    let registry = RuvLtraRegistry::new();
    registry.get(id).cloned()
}
// Unit tests for the registry contents, lookups, and recommendations.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_registry_initialization() {
        let registry = RuvLtraRegistry::new();
        assert!(registry.get("ruvltra-small").is_some());
        assert!(registry.get("ruvltra-medium").is_some());
        assert!(registry.get("nonexistent").is_none());
    }
    #[test]
    fn test_model_info() {
        let registry = RuvLtraRegistry::new();
        let model = registry.get("ruvltra-small").unwrap();
        assert_eq!(model.params_b, 0.5);
        assert_eq!(model.quantization, QuantizationLevel::Q4);
        assert!(model.has_sona_weights);
        assert!(!model.is_adapter);
    }
    #[test]
    fn test_list_by_size() {
        let registry = RuvLtraRegistry::new();
        let small_models = registry.list_by_size(ModelSize::Small);
        assert!(!small_models.is_empty());
    }
    #[test]
    fn test_adapters() {
        let registry = RuvLtraRegistry::new();
        let adapters = registry.list_adapters("ruvltra-small");
        assert!(!adapters.is_empty());
        assert!(adapters[0].is_adapter);
    }
    #[test]
    fn test_ram_recommendation() {
        let registry = RuvLtraRegistry::new();
        // Should recommend small model for 2GB
        let model = registry.recommend_for_ram(2.0);
        assert!(model.is_some());
        assert!(model.unwrap().params_b <= 1.0);
        // Should recommend medium model for 8GB
        let model = registry.recommend_for_ram(8.0);
        assert!(model.is_some());
    }
    #[test]
    fn test_quantization_multipliers() {
        assert_eq!(QuantizationLevel::Q4.size_multiplier(), 0.25);
        assert_eq!(QuantizationLevel::Q8.size_multiplier(), 0.5);
        assert_eq!(QuantizationLevel::FP16.size_multiplier(), 1.0);
    }
    #[test]
    fn test_model_urls() {
        let registry = RuvLtraRegistry::new();
        let model = registry.get("ruvltra-small").unwrap();
        let url = model.download_url();
        assert!(url.contains("huggingface.co"));
        assert!(url.contains("ruv/ruvltra"));
        assert!(url.contains(".gguf"));
        let hub_url = model.hub_url();
        assert_eq!(hub_url, "https://huggingface.co/ruv/ruvltra");
    }
    #[test]
    fn test_download_time_estimation() {
        let registry = RuvLtraRegistry::new();
        let model = registry.get("ruvltra-small").unwrap();
        // At 10 MB/s, should take ~66 seconds
        let time = model.estimate_download_time(10.0);
        assert!(time > 60.0 && time < 70.0);
    }
}

View File

@@ -0,0 +1,440 @@
//! Model upload functionality for publishing to HuggingFace Hub
use super::model_card::{ModelCard, ModelCardBuilder};
use super::{get_hf_token, HubError, Result};
use regex::Regex;
use std::fs;
use std::path::{Path, PathBuf};
// ============================================================================
// Security: Input Validation (H-002)
// ============================================================================
/// Validate repo_id format (prevents CLI injection)
/// Only allows: alphanumeric, /, -, _, .
fn validate_repo_id(repo_id: &str) -> Result<()> {
// Must contain exactly one slash (user/repo format)
let slash_count = repo_id.chars().filter(|&c| c == '/').count();
if slash_count != 1 {
return Err(HubError::InvalidFormat(
"Repository ID must be in format 'username/repo-name'".to_string(),
));
}
// Regex: only allow safe characters
let valid_pattern = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
.expect("Invalid regex pattern");
if !valid_pattern.is_match(repo_id) {
return Err(HubError::InvalidFormat(format!(
"Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
repo_id
)));
}
// Prevent path traversal
if repo_id.contains("..") {
return Err(HubError::InvalidFormat(
"Repository ID cannot contain '..' (path traversal)".to_string(),
));
}
// Prevent shell metacharacters that could be used for injection
let dangerous_chars = [
'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', '"', '\'', '\\',
];
for c in dangerous_chars {
if repo_id.contains(c) {
return Err(HubError::InvalidFormat(format!(
"Repository ID cannot contain shell metacharacter '{}'",
c
)));
}
}
Ok(())
}
/// Validate file path for upload (prevents path traversal)
///
/// # Errors
///
/// Returns `HubError::InvalidFormat` for paths with a `..` component, and
/// `HubError::NotFound` when the path cannot be resolved or is not a
/// regular file.
fn validate_upload_path(path: &Path) -> Result<()> {
    use std::path::Component;

    // Reject any `..` *component* (path traversal). Checking components
    // instead of the raw string avoids false positives on legitimate
    // filenames that merely contain ".." (e.g. "model..v2.gguf").
    if path.components().any(|c| matches!(c, Component::ParentDir)) {
        return Err(HubError::InvalidFormat(
            "File path cannot contain '..' (path traversal)".to_string(),
        ));
    }
    // Canonicalize to resolve any symlinks and verify it exists
    let canonical = path.canonicalize().map_err(|e| {
        HubError::NotFound(format!("Cannot resolve path '{}': {}", path.display(), e))
    })?;
    // Verify the file exists and is a regular file
    if !canonical.is_file() {
        return Err(HubError::NotFound(format!(
            "Path '{}' is not a regular file",
            path.display()
        )));
    }
    Ok(())
}
/// Upload configuration
///
/// Built with [`UploadConfig::new`] and adjusted via the builder methods.
#[derive(Debug, Clone)]
pub struct UploadConfig {
    /// HuggingFace token for authentication (required)
    pub hf_token: String,
    /// Make repository private
    pub private: bool,
    /// Create repository if it doesn't exist
    pub create_repo: bool,
    /// Upload SONA weights separately
    pub include_sona_weights: bool,
    /// Generate model card automatically
    pub auto_model_card: bool,
    /// Commit message
    pub commit_message: String,
}
impl UploadConfig {
    /// Build a config with defaults: public repo, auto-create enabled,
    /// SONA weights included, auto-generated model card.
    pub fn new(hf_token: String) -> Self {
        UploadConfig {
            hf_token,
            private: false,
            create_repo: true,
            include_sona_weights: true,
            auto_model_card: true,
            commit_message: String::from("Upload RuvLTRA model"),
        }
    }
    /// Builder: toggle repository visibility.
    pub fn private(mut self, private: bool) -> Self {
        self.private = private;
        self
    }
    /// Builder: override the commit message.
    pub fn commit_message(mut self, message: impl Into<String>) -> Self {
        self.commit_message = message.into();
        self
    }
}
/// Upload progress information
///
/// Snapshot of an in-flight upload; byte counts refer to the current file.
#[derive(Debug, Clone)]
pub struct UploadProgress {
    /// Total bytes to upload
    pub total_bytes: u64,
    /// Bytes uploaded so far
    pub uploaded_bytes: u64,
    /// Upload speed in bytes/sec
    pub speed_bps: f64,
    /// Name of the file currently being uploaded
    pub current_file: String,
    /// Current upload stage
    pub stage: UploadStage,
}
/// Upload stages, in the order they normally occur.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UploadStage {
    /// Preparing upload
    Preparing,
    /// Creating repository
    CreatingRepo,
    /// Uploading model file
    UploadingModel,
    /// Uploading SONA weights
    UploadingSona,
    /// Uploading model card (README.md)
    UploadingCard,
    /// Upload finished successfully
    Complete,
    /// Upload failed; the payload is an error description
    Failed(String),
}
/// Model metadata for upload
///
/// Used by the uploader to auto-generate a model card.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Model name
    pub name: String,
    /// Model description
    pub description: Option<String>,
    /// Model architecture (e.g., "llama")
    pub architecture: String,
    /// Number of parameters, in billions
    pub params_b: f32,
    /// Context length in tokens
    pub context_length: usize,
    /// Quantization type (e.g., "Q4_K_M")
    pub quantization: Option<String>,
    /// License identifier (e.g., "MIT")
    pub license: Option<String>,
    /// Training datasets
    pub datasets: Vec<String>,
    /// Tags for discovery
    pub tags: Vec<String>,
}
/// Model uploader
///
/// Wraps an [`UploadConfig`] and drives uploads through `huggingface-cli`.
pub struct ModelUploader {
    // Upload settings: token, visibility, commit message, etc.
    config: UploadConfig,
}
impl ModelUploader {
    /// Create a new uploader with HF token and default settings.
    pub fn new(hf_token: impl Into<String>) -> Self {
        Self {
            config: UploadConfig::new(hf_token.into()),
        }
    }
    /// Create uploader with custom config
    pub fn with_config(config: UploadConfig) -> Self {
        Self { config }
    }
    /// Upload a model file to HuggingFace Hub
    ///
    /// # Arguments
    ///
    /// * `model_path` - Path to the model file (.gguf)
    /// * `repo_id` - HuggingFace repository (e.g., "username/model-name")
    /// * `metadata` - Optional model metadata
    ///
    /// # Errors
    ///
    /// Fails when validation rejects the inputs, `huggingface-cli` is not
    /// installed, or a CLI invocation exits unsuccessfully.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let uploader = ModelUploader::new("hf_token");
    /// uploader.upload(
    ///     "./ruvltra-custom.gguf",
    ///     "username/ruvltra-custom",
    ///     Some(metadata),
    /// )?;
    /// ```
    pub fn upload(
        &self,
        model_path: impl AsRef<Path>,
        repo_id: &str,
        metadata: Option<ModelMetadata>,
    ) -> Result<String> {
        let model_path = model_path.as_ref();
        // SECURITY: Validate repository ID format (prevents CLI injection)
        validate_repo_id(repo_id)?;
        // SECURITY: Validate and canonicalize file path (prevents path traversal)
        validate_upload_path(model_path)?;
        // For now, use git-based upload via huggingface-cli
        // In production, this would use the HF API
        self.upload_via_cli(model_path, repo_id, metadata)
    }
    /// Drive the full upload via `huggingface-cli`; returns the Hub URL.
    fn upload_via_cli(
        &self,
        model_path: &Path,
        repo_id: &str,
        metadata: Option<ModelMetadata>,
    ) -> Result<String> {
        // Check if huggingface-cli is available
        if !self.has_hf_cli() {
            return Err(HubError::Config(
                "huggingface-cli not found. Install with: pip install huggingface_hub[cli]"
                    .to_string(),
            ));
        }
        // Create repository if needed
        if self.config.create_repo {
            self.create_repo_cli(repo_id)?;
        }
        // Upload model file
        self.upload_file_cli(model_path, repo_id)?;
        // Generate and upload model card if enabled
        if self.config.auto_model_card {
            if let Some(meta) = metadata {
                let card = self.generate_model_card(&meta);
                self.upload_model_card_cli(&card, repo_id)?;
            }
        }
        Ok(format!("https://huggingface.co/{}", repo_id))
    }
    /// Check whether `huggingface-cli` is installed and runnable on PATH.
    fn has_hf_cli(&self) -> bool {
        std::process::Command::new("huggingface-cli")
            .arg("--version")
            .output()
            .map(|o| o.status.success())
            .unwrap_or(false)
    }
    /// Create the repository via `huggingface-cli repo create`.
    fn create_repo_cli(&self, repo_id: &str) -> Result<()> {
        let mut cmd = std::process::Command::new("huggingface-cli");
        cmd.arg("repo").arg("create").arg(repo_id);
        if self.config.private {
            cmd.arg("--private");
        }
        // Token is passed via env, never on the command line.
        let status = cmd
            .env("HF_TOKEN", &self.config.hf_token)
            .status()
            .map_err(|e| HubError::Network(e.to_string()))?;
        // Exit code 1 might mean repo already exists
        if !status.success() && status.code() != Some(1) {
            return Err(HubError::Network("Failed to create repository".to_string()));
        }
        Ok(())
    }
    /// Upload a single file via `huggingface-cli upload`.
    fn upload_file_cli(&self, file_path: &Path, repo_id: &str) -> Result<()> {
        // Pass the path directly to `arg` (which accepts OsStr): the old
        // `file_path.to_str().unwrap()` panicked on non-UTF-8 paths.
        let status = std::process::Command::new("huggingface-cli")
            .arg("upload")
            .arg(repo_id)
            .arg(file_path)
            .arg("--commit-message")
            .arg(&self.config.commit_message)
            .env("HF_TOKEN", &self.config.hf_token)
            .status()
            .map_err(|e| HubError::Network(e.to_string()))?;
        if !status.success() {
            return Err(HubError::Network("Failed to upload file".to_string()));
        }
        Ok(())
    }
    /// Generate model card from metadata
    fn generate_model_card(&self, metadata: &ModelMetadata) -> ModelCard {
        use super::model_card::{Framework, License, TaskType};
        let mut builder = ModelCardBuilder::new(&metadata.name);
        if let Some(desc) = &metadata.description {
            builder = builder.description(desc);
        }
        builder = builder
            .task(TaskType::TextGeneration)
            .framework(Framework::Gguf)
            .architecture(&metadata.architecture)
            .parameters((metadata.params_b * 1e9) as u64)
            .context_length(metadata.context_length);
        if let Some(quant) = &metadata.quantization {
            builder = builder.add_tag(quant);
        }
        // Unparseable license strings are silently skipped.
        if let Some(license) = &metadata.license {
            if let Ok(lic) = license.parse() {
                builder = builder.license(lic);
            }
        }
        for dataset in &metadata.datasets {
            builder = builder.add_dataset(dataset, None);
        }
        for tag in &metadata.tags {
            builder = builder.add_tag(tag);
        }
        builder.build()
    }
    /// Render the model card and upload it as `README.md`.
    fn upload_model_card_cli(&self, card: &ModelCard, repo_id: &str) -> Result<()> {
        // Write the card into a per-process subdirectory of the system temp
        // dir; the old fixed `$TMP/README.md` path let concurrent uploads
        // clobber each other's card.
        let temp_dir =
            std::env::temp_dir().join(format!("ruvltra-upload-{}", std::process::id()));
        fs::create_dir_all(&temp_dir)?;
        let card_path = temp_dir.join("README.md");
        fs::write(&card_path, card.to_markdown())?;
        // Upload README.md
        self.upload_file_cli(&card_path, repo_id)?;
        // Clean up (best effort)
        let _ = fs::remove_file(&card_path);
        let _ = fs::remove_dir(&temp_dir);
        Ok(())
    }
}
/// Upload error type
///
/// NOTE(review): the visible code in this module returns `HubError`
/// everywhere; this type appears unused here — confirm against the rest of
/// the crate before removing.
#[derive(Debug, thiserror::Error)]
pub enum UploadError {
    /// Authentication error
    #[error("Authentication failed: {0}")]
    Auth(String),
    /// Network error
    #[error("Network error: {0}")]
    Network(String),
    /// IO error (auto-converted from `std::io::Error`)
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
// Unit tests for the upload configuration and metadata types.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_upload_config() {
        let config = UploadConfig::new("test_token".to_string());
        assert!(!config.private);
        assert!(config.create_repo);
        assert!(config.include_sona_weights);
    }
    #[test]
    fn test_upload_config_builder() {
        let config = UploadConfig::new("token".to_string())
            .private(true)
            .commit_message("Custom message");
        assert!(config.private);
        assert_eq!(config.commit_message, "Custom message");
    }
    #[test]
    fn test_model_metadata() {
        let metadata = ModelMetadata {
            name: "RuvLTRA Test".to_string(),
            description: Some("Test model".to_string()),
            architecture: "llama".to_string(),
            params_b: 0.5,
            context_length: 4096,
            quantization: Some("Q4_K_M".to_string()),
            license: Some("MIT".to_string()),
            datasets: vec!["dataset1".to_string()],
            tags: vec!["test".to_string()],
        };
        assert_eq!(metadata.params_b, 0.5);
        assert!(metadata.description.is_some());
    }
}