Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
576
vendor/ruvector/crates/ruvllm/src/hub/download.rs
vendored
Normal file
576
vendor/ruvector/crates/ruvllm/src/hub/download.rs
vendored
Normal file
@@ -0,0 +1,576 @@
|
||||
//! Model download functionality with progress tracking and resume support
|
||||
|
||||
use super::progress::{ProgressBar, ProgressStyle};
|
||||
use super::registry::ModelInfo;
|
||||
use super::{default_cache_dir, get_hf_token, HubError, Result};
|
||||
use regex::Regex;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::fs::{self, File};
|
||||
use std::io::{self, BufWriter, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// ============================================================================
|
||||
// Security: URL and Input Validation (H-001)
|
||||
// ============================================================================
|
||||
|
||||
/// Allowed domains for HuggingFace downloads
|
||||
const ALLOWED_DOMAINS: &[&str] = &["huggingface.co", "hf.co", "cdn-lfs.huggingface.co"];
|
||||
|
||||
/// Validate URL is from allowed HuggingFace domains
|
||||
fn validate_url(url: &str) -> Result<()> {
|
||||
// Parse the URL to extract the host
|
||||
let url_lower = url.to_lowercase();
|
||||
|
||||
// Check for valid HTTPS scheme
|
||||
if !url_lower.starts_with("https://") {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"Only HTTPS URLs are allowed for downloads".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Extract host from URL
|
||||
let without_scheme = &url[8..]; // Skip "https://"
|
||||
let host_end = without_scheme.find('/').unwrap_or(without_scheme.len());
|
||||
let host = &without_scheme[..host_end];
|
||||
|
||||
// Remove port if present
|
||||
let host = host.split(':').next().unwrap_or(host);
|
||||
|
||||
// Check against allowlist
|
||||
let is_allowed = ALLOWED_DOMAINS
|
||||
.iter()
|
||||
.any(|&domain| host == domain || host.ends_with(&format!(".{}", domain)));
|
||||
|
||||
if !is_allowed {
|
||||
return Err(HubError::InvalidFormat(format!(
|
||||
"URL host '{}' is not in the allowed domains: {:?}",
|
||||
host, ALLOWED_DOMAINS
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate repo_id format (prevents CLI injection)
|
||||
/// Only allows: alphanumeric, /, -, _, .
|
||||
fn validate_repo_id(repo_id: &str) -> Result<()> {
|
||||
// Must contain exactly one slash (user/repo format)
|
||||
let slash_count = repo_id.chars().filter(|&c| c == '/').count();
|
||||
if slash_count != 1 {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"Repository ID must be in format 'username/repo-name'".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Regex: only allow safe characters
|
||||
let valid_pattern = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
|
||||
.expect("Invalid regex pattern");
|
||||
|
||||
if !valid_pattern.is_match(repo_id) {
|
||||
return Err(HubError::InvalidFormat(format!(
|
||||
"Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
|
||||
repo_id
|
||||
)));
|
||||
}
|
||||
|
||||
// Prevent path traversal
|
||||
if repo_id.contains("..") {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"Repository ID cannot contain '..' (path traversal)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Canonicalize and validate file path to prevent path traversal
///
/// Resolves `path` to an absolute, symlink-free form and verifies it lies
/// inside `base_dir`. For not-yet-existing files the parent directory is
/// canonicalized and the file name re-appended, since `canonicalize` fails
/// on missing paths.
///
/// NOTE(review): `create_dir_all` runs *before* the containment check, so a
/// traversal-shaped `path` can create directories outside `base_dir` even
/// though the function ultimately returns an error — consider validating
/// first. There is also an inherent exists/canonicalize TOCTOU window.
///
/// # Errors
/// `HubError::Config` when canonicalization fails; `HubError::InvalidFormat`
/// when the path has no parent/file name or escapes `base_dir`.
fn validate_and_canonicalize_path(path: &Path, base_dir: &Path) -> Result<PathBuf> {
    // Canonicalize both paths
    let canonical_base = base_dir
        .canonicalize()
        .map_err(|e| HubError::Config(format!("Failed to canonicalize base directory: {}", e)))?;

    // Create parent directories if needed, then canonicalize
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }

    // For new files, canonicalize the parent and append filename
    let canonical_path = if path.exists() {
        path.canonicalize()
            .map_err(|e| HubError::Config(format!("Failed to canonicalize path: {}", e)))?
    } else if let Some(parent) = path.parent() {
        let canonical_parent = parent
            .canonicalize()
            .map_err(|e| HubError::Config(format!("Failed to canonicalize parent path: {}", e)))?;
        canonical_parent.join(
            path.file_name()
                .ok_or_else(|| HubError::InvalidFormat("Invalid file path".to_string()))?,
        )
    } else {
        // No parent component at all (e.g. an empty path).
        return Err(HubError::InvalidFormat("Invalid file path".to_string()));
    };

    // Ensure the path is within the base directory
    if !canonical_path.starts_with(&canonical_base) {
        return Err(HubError::InvalidFormat(format!(
            "Path '{}' is outside allowed directory '{}'",
            canonical_path.display(),
            canonical_base.display()
        )));
    }

    Ok(canonical_path)
}
|
||||
|
||||
/// Download configuration
///
/// Defaults come from [`default_cache_dir`] and the `HF_TOKEN` family of
/// environment variables (see [`get_hf_token`]).
#[derive(Debug, Clone)]
pub struct DownloadConfig {
    /// Target directory for downloads
    pub cache_dir: PathBuf,
    /// HuggingFace token for authentication
    pub hf_token: Option<String>,
    /// Enable resume for interrupted downloads
    pub resume: bool,
    /// Show progress bar
    // NOTE(review): not consulted by ModelDownloader in this file — the
    // curl/wget progress flags are always passed. Verify intended use.
    pub show_progress: bool,
    /// Verify checksum after download
    pub verify_checksum: bool,
    /// Maximum retry attempts
    // NOTE(review): not read by any code visible in this file.
    pub max_retries: u32,
}

impl Default for DownloadConfig {
    fn default() -> Self {
        Self {
            cache_dir: default_cache_dir(),
            hf_token: get_hf_token(),
            resume: true,
            show_progress: true,
            verify_checksum: true,
            max_retries: 3,
        }
    }
}
|
||||
|
||||
/// Download progress information
///
/// A plain data snapshot; all fields are supplied by the caller that tracks
/// the transfer. See [`DownloadProgress::percentage`] and the formatting
/// helpers for derived views.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Total bytes to download
    pub total_bytes: u64,
    /// Bytes downloaded so far
    pub downloaded_bytes: u64,
    /// Download speed in bytes/sec
    pub speed_bps: f64,
    /// Estimated time remaining in seconds
    pub eta_seconds: f64,
    /// Current stage
    pub stage: DownloadStage,
}
|
||||
|
||||
/// Download stages
///
/// Lifecycle: `Preparing` → `Downloading` → `Verifying` → `Complete`,
/// with `Failed` carrying a human-readable reason on error.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStage {
    /// Preparing download
    Preparing,
    /// Downloading file
    Downloading,
    /// Verifying checksum
    Verifying,
    /// Complete
    Complete,
    /// Failed, with an error description
    Failed(String),
}
|
||||
|
||||
impl DownloadProgress {
|
||||
/// Calculate progress percentage
|
||||
pub fn percentage(&self) -> f32 {
|
||||
if self.total_bytes == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(self.downloaded_bytes as f64 / self.total_bytes as f64 * 100.0) as f32
|
||||
}
|
||||
}
|
||||
|
||||
/// Format speed as human-readable string
|
||||
pub fn speed_str(&self) -> String {
|
||||
format_bytes_per_sec(self.speed_bps)
|
||||
}
|
||||
|
||||
/// Format ETA as human-readable string
|
||||
pub fn eta_str(&self) -> String {
|
||||
format_duration(self.eta_seconds as u64)
|
||||
}
|
||||
}
|
||||
|
||||
/// Checksum verifier
|
||||
pub struct ChecksumVerifier {
|
||||
hasher: Sha256,
|
||||
bytes_hashed: u64,
|
||||
}
|
||||
|
||||
impl ChecksumVerifier {
|
||||
/// Create a new checksum verifier
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
hasher: Sha256::new(),
|
||||
bytes_hashed: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update with new data
|
||||
pub fn update(&mut self, data: &[u8]) {
|
||||
self.hasher.update(data);
|
||||
self.bytes_hashed += data.len() as u64;
|
||||
}
|
||||
|
||||
/// Finalize and get checksum
|
||||
pub fn finalize(self) -> String {
|
||||
format!("{:x}", self.hasher.finalize())
|
||||
}
|
||||
|
||||
/// Verify against expected checksum
|
||||
pub fn verify(self, expected: &str) -> Result<()> {
|
||||
let actual = self.finalize();
|
||||
if actual == expected {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(HubError::ChecksumMismatch {
|
||||
expected: expected.to_string(),
|
||||
actual,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChecksumVerifier {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Model downloader
///
/// Thin wrapper around a [`DownloadConfig`]; downloads are delegated to the
/// external `curl`/`wget` binaries (see the `impl` below).
pub struct ModelDownloader {
    config: DownloadConfig,
}
|
||||
|
||||
impl ModelDownloader {
|
||||
/// Create a new downloader with default config
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
config: DownloadConfig::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a downloader with custom config
|
||||
pub fn with_config(config: DownloadConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Download a model by ID from the registry
|
||||
pub fn download_by_id(&self, model_id: &str) -> Result<PathBuf> {
|
||||
let registry = super::registry::RuvLtraRegistry::new();
|
||||
let model_info = registry
|
||||
.get(model_id)
|
||||
.ok_or_else(|| HubError::NotFound(model_id.to_string()))?;
|
||||
|
||||
self.download(model_info, None)
|
||||
}
|
||||
|
||||
/// Download a model from ModelInfo
|
||||
pub fn download(&self, model_info: &ModelInfo, target_path: Option<&Path>) -> Result<PathBuf> {
|
||||
// Determine target path
|
||||
let path = if let Some(p) = target_path {
|
||||
p.to_path_buf()
|
||||
} else {
|
||||
self.config.cache_dir.join(&model_info.filename)
|
||||
};
|
||||
|
||||
// SECURITY: Validate and canonicalize path to prevent path traversal
|
||||
let path = validate_and_canonicalize_path(&path, &self.config.cache_dir)?;
|
||||
|
||||
// Check if already downloaded
|
||||
if path.exists() && !self.config.resume {
|
||||
if self.config.verify_checksum {
|
||||
if let Some(checksum) = &model_info.checksum {
|
||||
self.verify_file(&path, checksum)?;
|
||||
}
|
||||
}
|
||||
return Ok(path);
|
||||
}
|
||||
|
||||
// Download the file
|
||||
let url = model_info.download_url();
|
||||
|
||||
// SECURITY: Validate URL is from allowed domains
|
||||
validate_url(&url)?;
|
||||
|
||||
self.download_file(
|
||||
&url,
|
||||
&path,
|
||||
model_info.size_bytes,
|
||||
model_info.checksum.as_deref(),
|
||||
)?;
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Download a file from URL
|
||||
fn download_file(
|
||||
&self,
|
||||
url: &str,
|
||||
path: &Path,
|
||||
expected_size: u64,
|
||||
expected_checksum: Option<&str>,
|
||||
) -> Result<()> {
|
||||
// Use curl/wget if available, otherwise fail with helpful message
|
||||
if self.has_curl() {
|
||||
self.download_with_curl(url, path, expected_size, expected_checksum)
|
||||
} else if self.has_wget() {
|
||||
self.download_with_wget(url, path, expected_size, expected_checksum)
|
||||
} else {
|
||||
Err(HubError::Config(
|
||||
"Download requires curl or wget. Please install: brew install curl (macOS) or apt install curl (Linux)"
|
||||
.to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if curl is available
|
||||
fn has_curl(&self) -> bool {
|
||||
std::process::Command::new("which")
|
||||
.arg("curl")
|
||||
.output()
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Check if wget is available
|
||||
fn has_wget(&self) -> bool {
|
||||
std::process::Command::new("which")
|
||||
.arg("wget")
|
||||
.output()
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Download using curl
|
||||
fn download_with_curl(
|
||||
&self,
|
||||
url: &str,
|
||||
path: &Path,
|
||||
_expected_size: u64,
|
||||
expected_checksum: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let mut args = vec![
|
||||
"-L".to_string(), // Follow redirects
|
||||
"-#".to_string(), // Progress bar
|
||||
"--fail".to_string(), // Fail on HTTP errors
|
||||
];
|
||||
|
||||
// Add resume flag if enabled
|
||||
if self.config.resume && path.exists() {
|
||||
args.push("-C".to_string());
|
||||
args.push("-".to_string()); // Auto-resume
|
||||
}
|
||||
|
||||
// Add auth token if provided
|
||||
if let Some(token) = &self.config.hf_token {
|
||||
args.push("-H".to_string());
|
||||
args.push(format!("Authorization: Bearer {}", token));
|
||||
}
|
||||
|
||||
args.push("-o".to_string());
|
||||
args.push(path.to_str().unwrap().to_string());
|
||||
args.push(url.to_string());
|
||||
|
||||
let status = std::process::Command::new("curl")
|
||||
.args(&args)
|
||||
.status()
|
||||
.map_err(|e| HubError::Network(e.to_string()))?;
|
||||
|
||||
if !status.success() {
|
||||
return Err(HubError::Network(format!(
|
||||
"curl failed with status: {}",
|
||||
status
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify checksum if provided
|
||||
if self.config.verify_checksum {
|
||||
if let Some(checksum) = expected_checksum {
|
||||
self.verify_file(path, checksum)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Download using wget
|
||||
fn download_with_wget(
|
||||
&self,
|
||||
url: &str,
|
||||
path: &Path,
|
||||
_expected_size: u64,
|
||||
expected_checksum: Option<&str>,
|
||||
) -> Result<()> {
|
||||
let mut args = vec![
|
||||
"-q".to_string(), // Quiet
|
||||
"--show-progress".to_string(), // But show progress
|
||||
];
|
||||
|
||||
// Add resume flag if enabled
|
||||
if self.config.resume && path.exists() {
|
||||
args.push("-c".to_string()); // Continue
|
||||
}
|
||||
|
||||
// Add auth token if provided
|
||||
if let Some(token) = &self.config.hf_token {
|
||||
args.push("--header".to_string());
|
||||
args.push(format!("Authorization: Bearer {}", token));
|
||||
}
|
||||
|
||||
args.push("-O".to_string());
|
||||
args.push(path.to_str().unwrap().to_string());
|
||||
args.push(url.to_string());
|
||||
|
||||
let status = std::process::Command::new("wget")
|
||||
.args(&args)
|
||||
.status()
|
||||
.map_err(|e| HubError::Network(e.to_string()))?;
|
||||
|
||||
if !status.success() {
|
||||
return Err(HubError::Network(format!(
|
||||
"wget failed with status: {}",
|
||||
status
|
||||
)));
|
||||
}
|
||||
|
||||
// Verify checksum if provided
|
||||
if self.config.verify_checksum {
|
||||
if let Some(checksum) = expected_checksum {
|
||||
self.verify_file(path, checksum)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Verify file checksum
|
||||
fn verify_file(&self, path: &Path, expected_checksum: &str) -> Result<()> {
|
||||
use std::io::Read;
|
||||
|
||||
let mut file = File::open(path)?;
|
||||
let mut verifier = ChecksumVerifier::new();
|
||||
let mut buffer = [0u8; 8192];
|
||||
|
||||
loop {
|
||||
let n = file.read(&mut buffer)?;
|
||||
if n == 0 {
|
||||
break;
|
||||
}
|
||||
verifier.update(&buffer[..n]);
|
||||
}
|
||||
|
||||
verifier.verify(expected_checksum)
|
||||
}
|
||||
}
|
||||
|
||||
// Default downloader: equivalent to `ModelDownloader::new()` with
// environment-derived configuration.
impl Default for ModelDownloader {
    fn default() -> Self {
        Self::new()
    }
}
|
||||
|
||||
/// Download error type
///
/// NOTE(review): not referenced by any function in this file — the
/// downloader uses `HubError` throughout. Possibly dead code or reserved
/// for a future async implementation; verify before removing.
#[derive(Debug, thiserror::Error)]
pub enum DownloadError {
    /// HTTP error
    #[error("HTTP error: {0}")]
    Http(String),
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] io::Error),
    /// Checksum mismatch
    #[error("Checksum verification failed")]
    ChecksumMismatch,
}
|
||||
|
||||
/// Format a transfer rate as a human-readable string.
///
/// Uses binary (1024-based) units: two decimals for KB/s and above,
/// whole numbers for plain B/s.
fn format_bytes_per_sec(bps: f64) -> String {
    let kb = 1024.0_f64;
    let units = [(kb * kb * kb, "GB/s"), (kb * kb, "MB/s"), (kb, "KB/s")];

    for &(scale, unit) in units.iter() {
        if bps >= scale {
            return format!("{:.2} {}", bps / scale, unit);
        }
    }
    format!("{:.0} B/s", bps)
}
|
||||
|
||||
/// Format a duration in whole seconds as `"Ns"`, `"Mm Ss"`, or `"Hh Mm"`.
fn format_duration(secs: u64) -> String {
    match secs {
        0..=59 => format!("{}s", secs),
        60..=3599 => format!("{}m {}s", secs / 60, secs % 60),
        _ => format!("{}h {}m", secs / 3600, (secs % 3600) / 60),
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_download_config_default() {
        let config = DownloadConfig::default();
        assert!(config.resume);
        assert!(config.show_progress);
        assert!(config.verify_checksum);
    }

    #[test]
    fn test_download_progress() {
        let progress = DownloadProgress {
            total_bytes: 1000,
            downloaded_bytes: 500,
            speed_bps: 1024.0 * 1024.0, // exactly 1 MB/s
            eta_seconds: 30.0,
            stage: DownloadStage::Downloading,
        };

        // 500 / 1000 bytes -> 50%
        assert_eq!(progress.percentage(), 50.0);
        assert!(progress.speed_str().contains("MB/s"));
    }

    #[test]
    fn test_checksum_verifier() {
        let mut verifier = ChecksumVerifier::new();
        verifier.update(b"hello world");
        let checksum = verifier.finalize();
        assert!(!checksum.is_empty());
        assert_eq!(checksum.len(), 64); // SHA256 hex is 64 chars
    }

    #[test]
    fn test_format_bytes_per_sec() {
        assert_eq!(format_bytes_per_sec(500.0), "500 B/s");
        assert_eq!(format_bytes_per_sec(1024.0 * 10.0), "10.00 KB/s");
        assert_eq!(format_bytes_per_sec(1024.0 * 1024.0 * 5.0), "5.00 MB/s");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(30), "30s");
        assert_eq!(format_duration(90), "1m 30s");
        assert_eq!(format_duration(3700), "1h 1m");
    }
}
|
||||
145
vendor/ruvector/crates/ruvllm/src/hub/mod.rs
vendored
Normal file
145
vendor/ruvector/crates/ruvllm/src/hub/mod.rs
vendored
Normal file
@@ -0,0 +1,145 @@
|
||||
//! HuggingFace Hub integration for RuvLTRA model management
|
||||
//!
|
||||
//! This module provides comprehensive HuggingFace Hub integration for publishing,
|
||||
//! downloading, and managing RuvLTRA models. It supports:
|
||||
//!
|
||||
//! - **Model Upload**: Push GGUF files and SONA weights to HF Hub
|
||||
//! - **Model Download**: Pull models with automatic quantization selection
|
||||
//! - **Model Registry**: Pre-configured RuvLTRA model collection
|
||||
//! - **Progress Tracking**: Visual progress bars with resume support
|
||||
//! - **Integrity Verification**: Checksum validation for downloads
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,ignore
|
||||
//! use ruvllm::hub::{RuvLtraRegistry, ModelDownloader};
|
||||
//!
|
||||
//! // Download a model
|
||||
//! let registry = RuvLtraRegistry::new();
|
||||
//! let model_info = registry.get("ruvltra-small")?;
|
||||
//! let downloader = ModelDownloader::new();
|
||||
//! let path = downloader.download(model_info, None)?;
|
||||
//!
|
||||
//! // Upload a model
|
||||
//! let uploader = ModelUploader::new("hf_token_here");
|
||||
//! uploader.upload(
|
||||
//! "./my-model.gguf",
|
||||
//! "username/my-ruvltra",
|
||||
//! Some("My custom RuvLTRA model"),
|
||||
//! ).await?;
|
||||
//! ```
|
||||
|
||||
pub mod download;
|
||||
pub mod model_card;
|
||||
pub mod progress;
|
||||
pub mod registry;
|
||||
pub mod upload;
|
||||
|
||||
// Re-exports
|
||||
pub use download::{
|
||||
ChecksumVerifier, DownloadConfig, DownloadError, DownloadProgress, ModelDownloader,
|
||||
};
|
||||
pub use model_card::{
|
||||
DatasetInfo, Framework, License, MetricResult, ModelCard, ModelCardBuilder, TaskType,
|
||||
};
|
||||
pub use progress::{
|
||||
MultiProgress, ProgressBar, ProgressCallback, ProgressIndicator, ProgressStyle,
|
||||
};
|
||||
pub use registry::{
|
||||
get_model_info, HardwareRequirements, ModelInfo, ModelSize, QuantizationLevel, RuvLtraRegistry,
|
||||
};
|
||||
pub use upload::{ModelMetadata, ModelUploader, UploadConfig, UploadError, UploadProgress};
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Result type for hub operations
pub type Result<T> = std::result::Result<T, HubError>;

/// Hub operation errors
///
/// All fallible hub APIs in this module tree return these variants via the
/// [`Result`] alias above. Display strings come from the `thiserror`
/// `#[error(...)]` attributes.
#[derive(Debug, thiserror::Error)]
pub enum HubError {
    /// IO error
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// HTTP error
    // Only compiled when the `async-runtime` feature is enabled.
    #[cfg(feature = "async-runtime")]
    #[error("HTTP error: {0}")]
    Http(String),

    /// Authentication error
    #[error("Authentication failed: {0}")]
    Auth(String),

    /// Model not found
    #[error("Model not found: {0}")]
    NotFound(String),

    /// Checksum mismatch
    #[error("Checksum verification failed: expected {expected}, got {actual}")]
    ChecksumMismatch { expected: String, actual: String },

    /// Invalid model format
    #[error("Invalid model format: {0}")]
    InvalidFormat(String),

    /// Rate limit exceeded
    #[error("Rate limit exceeded. Retry after {0} seconds")]
    RateLimit(u64),

    /// Network error
    #[error("Network error: {0}")]
    Network(String),

    /// Parse error
    #[error("Parse error: {0}")]
    Parse(String),

    /// Configuration error
    #[error("Configuration error: {0}")]
    Config(String),
}
|
||||
|
||||
/// Default HuggingFace Hub API endpoint
pub const HF_ENDPOINT: &str = "https://huggingface.co";

/// Default cache directory for downloaded models
///
/// Resolves to `<platform cache dir>/huggingface/ruvltra`, falling back to
/// `./huggingface/ruvltra` when the platform cache directory cannot be
/// determined.
pub fn default_cache_dir() -> PathBuf {
    dirs::cache_dir()
        .unwrap_or_else(|| PathBuf::from("."))
        .join("huggingface")
        .join("ruvltra")
}
|
||||
|
||||
/// Get HuggingFace token from environment
///
/// Checks `HF_TOKEN`, then `HUGGING_FACE_HUB_TOKEN`, then
/// `HUGGINGFACE_API_KEY`, returning the first variable that is set.
pub fn get_hf_token() -> Option<String> {
    const TOKEN_VARS: [&str; 3] = ["HF_TOKEN", "HUGGING_FACE_HUB_TOKEN", "HUGGINGFACE_API_KEY"];
    TOKEN_VARS
        .iter()
        .copied()
        .find_map(|var| std::env::var(var).ok())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_cache_dir() {
        let cache_dir = default_cache_dir();
        assert!(cache_dir.to_string_lossy().contains("huggingface"));
        assert!(cache_dir.to_string_lossy().contains("ruvltra"));
    }

    #[test]
    fn test_error_display() {
        let err = HubError::NotFound("model-123".to_string());
        assert_eq!(err.to_string(), "Model not found: model-123");

        // Both checksum values should appear in the thiserror-formatted message.
        let err = HubError::ChecksumMismatch {
            expected: "abc123".to_string(),
            actual: "def456".to_string(),
        };
        assert!(err.to_string().contains("abc123"));
        assert!(err.to_string().contains("def456"));
    }
}
|
||||
426
vendor/ruvector/crates/ruvllm/src/hub/model_card.rs
vendored
Normal file
426
vendor/ruvector/crates/ruvllm/src/hub/model_card.rs
vendored
Normal file
@@ -0,0 +1,426 @@
|
||||
//! Model card generation for HuggingFace Hub
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Model task type
///
/// Serialized in kebab-case (e.g. `text-generation`) to match
/// HuggingFace Hub task identifiers.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum TaskType {
    /// Free-form text generation
    TextGeneration,
    /// Conversational / chat models
    ConversationalAi,
    /// Code completion
    CodeCompletion,
    /// Question answering
    QuestionAnswering,
    /// Text summarization
    Summarization,
}
|
||||
|
||||
/// ML framework
///
/// Serialized in lowercase (e.g. `gguf`, `pytorch`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Framework {
    /// GGUF quantized model format
    Gguf,
    /// PyTorch
    PyTorch,
    /// TensorFlow
    TensorFlow,
    /// ONNX
    Onnx,
}
|
||||
|
||||
/// Model license
///
/// Serialized in kebab-case; `Other` carries an arbitrary identifier for
/// licenses not covered by the named variants (see the `FromStr` impl).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum License {
    /// MIT license
    Mit,
    /// Apache License 2.0
    Apache20,
    /// GNU GPL v3.0
    Gpl30,
    /// BSD 3-Clause
    Bsd3Clause,
    /// CreativeML OpenRAIL-M
    CreativemlOpenrailM,
    /// Llama 2 community license
    Llama2,
    /// Any other license identifier, stored verbatim
    Other(String),
}
|
||||
|
||||
impl std::str::FromStr for License {
|
||||
type Err = ();
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"mit" => Ok(Self::Mit),
|
||||
"apache-2.0" | "apache2.0" => Ok(Self::Apache20),
|
||||
"gpl-3.0" | "gpl3.0" => Ok(Self::Gpl30),
|
||||
"bsd-3-clause" => Ok(Self::Bsd3Clause),
|
||||
"creativeml-openrail-m" => Ok(Self::CreativemlOpenrailM),
|
||||
"llama2" => Ok(Self::Llama2),
|
||||
other => Ok(Self::Other(other.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Dataset information
///
/// Rendered as a bullet in the "Training Data" section of the model card.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatasetInfo {
    /// Dataset name/identifier
    pub name: String,
    /// Dataset description
    pub description: Option<String>,
}
|
||||
|
||||
/// Metric result
///
/// Rendered as a row in the "Evaluation" table of the model card
/// (value formatted to two decimal places).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricResult {
    /// Metric name (e.g. "perplexity", "accuracy")
    pub name: String,
    /// Metric value
    pub value: f64,
    /// Dataset used for evaluation ("N/A" in the rendered table when absent)
    pub dataset: Option<String>,
}
|
||||
|
||||
/// Model card for HuggingFace Hub
///
/// Construct via [`ModelCardBuilder`]; render with [`ModelCard::to_markdown`],
/// which emits YAML frontmatter followed by markdown sections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCard {
    /// Model name
    pub name: String,
    /// Short description
    pub description: Option<String>,
    /// Task type
    pub task: TaskType,
    /// Framework
    pub framework: Framework,
    /// Architecture (e.g., "llama", "qwen2")
    pub architecture: String,
    /// Model license
    pub license: License,
    /// Number of parameters
    pub parameters: u64,
    /// Context window size
    pub context_length: usize,
    /// Training datasets
    pub datasets: Vec<DatasetInfo>,
    /// Evaluation metrics
    pub metrics: Vec<MetricResult>,
    /// Model tags
    pub tags: Vec<String>,
    /// Additional metadata (free-form key/value pairs)
    pub metadata: HashMap<String, String>,
}
|
||||
|
||||
impl ModelCard {
|
||||
/// Convert model card to YAML frontmatter + markdown
|
||||
pub fn to_markdown(&self) -> String {
|
||||
let mut content = String::new();
|
||||
|
||||
// YAML frontmatter
|
||||
content.push_str("---\n");
|
||||
content.push_str(&format!("language: en\n"));
|
||||
content.push_str(&format!("license: {}\n", self.license_str()));
|
||||
content.push_str(&format!("library_name: ruvltra\n"));
|
||||
|
||||
if !self.tags.is_empty() {
|
||||
content.push_str("tags:\n");
|
||||
for tag in &self.tags {
|
||||
content.push_str(&format!("- {}\n", tag));
|
||||
}
|
||||
}
|
||||
|
||||
content.push_str("---\n\n");
|
||||
|
||||
// Model description
|
||||
content.push_str(&format!("# {}\n\n", self.name));
|
||||
|
||||
if let Some(desc) = &self.description {
|
||||
content.push_str(&format!("{}\n\n", desc));
|
||||
}
|
||||
|
||||
// Model details
|
||||
content.push_str("## Model Details\n\n");
|
||||
content.push_str(&format!("- **Architecture**: {}\n", self.architecture));
|
||||
content.push_str(&format!(
|
||||
"- **Parameters**: {}\n",
|
||||
format_params(self.parameters)
|
||||
));
|
||||
content.push_str(&format!(
|
||||
"- **Context Length**: {} tokens\n",
|
||||
self.context_length
|
||||
));
|
||||
content.push_str(&format!("- **Framework**: {:?}\n", self.framework));
|
||||
content.push_str(&format!("- **Task**: {:?}\n\n", self.task));
|
||||
|
||||
// Training data
|
||||
if !self.datasets.is_empty() {
|
||||
content.push_str("## Training Data\n\n");
|
||||
for dataset in &self.datasets {
|
||||
content.push_str(&format!("- **{}**", dataset.name));
|
||||
if let Some(desc) = &dataset.description {
|
||||
content.push_str(&format!(": {}", desc));
|
||||
}
|
||||
content.push_str("\n");
|
||||
}
|
||||
content.push_str("\n");
|
||||
}
|
||||
|
||||
// Evaluation metrics
|
||||
if !self.metrics.is_empty() {
|
||||
content.push_str("## Evaluation\n\n");
|
||||
content.push_str("| Metric | Value | Dataset |\n");
|
||||
content.push_str("|--------|-------|----------|\n");
|
||||
for metric in &self.metrics {
|
||||
content.push_str(&format!(
|
||||
"| {} | {:.2} | {} |\n",
|
||||
metric.name,
|
||||
metric.value,
|
||||
metric.dataset.as_deref().unwrap_or("N/A")
|
||||
));
|
||||
}
|
||||
content.push_str("\n");
|
||||
}
|
||||
|
||||
// Usage
|
||||
content.push_str("## Usage\n\n");
|
||||
content.push_str("```bash\n");
|
||||
content.push_str("# Download using ruvllm CLI\n");
|
||||
content.push_str(&format!("ruvllm pull {}\n", self.name.to_lowercase()));
|
||||
content.push_str("```\n\n");
|
||||
|
||||
content.push_str("```rust\n");
|
||||
content.push_str("use ruvllm::hub::ModelDownloader;\n\n");
|
||||
content.push_str("let downloader = ModelDownloader::new();\n");
|
||||
content.push_str(&format!(
|
||||
"let path = downloader.download_by_id(\"{}\")?;\n",
|
||||
self.name.to_lowercase()
|
||||
));
|
||||
content.push_str("```\n\n");
|
||||
|
||||
// Additional metadata
|
||||
if !self.metadata.is_empty() {
|
||||
content.push_str("## Additional Information\n\n");
|
||||
for (key, value) in &self.metadata {
|
||||
content.push_str(&format!("- **{}**: {}\n", key, value));
|
||||
}
|
||||
content.push_str("\n");
|
||||
}
|
||||
|
||||
// Footer
|
||||
content.push_str("---\n\n");
|
||||
content.push_str("*This model card was generated automatically by RuvLLM*\n");
|
||||
|
||||
content
|
||||
}
|
||||
|
||||
/// Get license as string
|
||||
fn license_str(&self) -> &str {
|
||||
match &self.license {
|
||||
License::Mit => "mit",
|
||||
License::Apache20 => "apache-2.0",
|
||||
License::Gpl30 => "gpl-3.0",
|
||||
License::Bsd3Clause => "bsd-3-clause",
|
||||
License::CreativemlOpenrailM => "creativeml-openrail-m",
|
||||
License::Llama2 => "llama2",
|
||||
License::Other(s) => s,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Model card builder
///
/// Fields mirror [`ModelCard`]; see [`ModelCardBuilder::new`] for the
/// defaults applied before any setter is called.
pub struct ModelCardBuilder {
    name: String,
    description: Option<String>,
    task: TaskType,
    framework: Framework,
    architecture: String,
    license: License,
    parameters: u64,
    context_length: usize,
    datasets: Vec<DatasetInfo>,
    metrics: Vec<MetricResult>,
    tags: Vec<String>,
    metadata: HashMap<String, String>,
}
|
||||
|
||||
// Builder methods consume and return `self` so calls can be chained fluently.
impl ModelCardBuilder {
    /// Create a new model card builder
    ///
    /// Defaults: text-generation task, GGUF framework, "llama" architecture,
    /// MIT license, 0 parameters, 4096-token context, no datasets/metrics/
    /// tags/metadata.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
            task: TaskType::TextGeneration,
            framework: Framework::Gguf,
            architecture: "llama".to_string(),
            license: License::Mit,
            parameters: 0,
            context_length: 4096,
            datasets: Vec::new(),
            metrics: Vec::new(),
            tags: Vec::new(),
            metadata: HashMap::new(),
        }
    }

    /// Set description
    pub fn description(mut self, desc: impl Into<String>) -> Self {
        self.description = Some(desc.into());
        self
    }

    /// Set task type
    pub fn task(mut self, task: TaskType) -> Self {
        self.task = task;
        self
    }

    /// Set framework
    pub fn framework(mut self, framework: Framework) -> Self {
        self.framework = framework;
        self
    }

    /// Set architecture
    pub fn architecture(mut self, arch: impl Into<String>) -> Self {
        self.architecture = arch.into();
        self
    }

    /// Set license
    pub fn license(mut self, license: License) -> Self {
        self.license = license;
        self
    }

    /// Set parameter count
    pub fn parameters(mut self, params: u64) -> Self {
        self.parameters = params;
        self
    }

    /// Set context length
    pub fn context_length(mut self, length: usize) -> Self {
        self.context_length = length;
        self
    }

    /// Add a dataset (with an optional description)
    pub fn add_dataset(mut self, name: impl Into<String>, desc: Option<String>) -> Self {
        self.datasets.push(DatasetInfo {
            name: name.into(),
            description: desc,
        });
        self
    }

    /// Add a metric (with the dataset it was evaluated on, if any)
    pub fn add_metric(
        mut self,
        name: impl Into<String>,
        value: f64,
        dataset: Option<String>,
    ) -> Self {
        self.metrics.push(MetricResult {
            name: name.into(),
            value,
            dataset,
        });
        self
    }

    /// Add a tag
    pub fn add_tag(mut self, tag: impl Into<String>) -> Self {
        self.tags.push(tag.into());
        self
    }

    /// Add metadata (later inserts with the same key overwrite earlier ones)
    pub fn add_metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.metadata.insert(key.into(), value.into());
        self
    }

    /// Build the model card, consuming the builder
    pub fn build(self) -> ModelCard {
        ModelCard {
            name: self.name,
            description: self.description,
            task: self.task,
            framework: self.framework,
            architecture: self.architecture,
            license: self.license,
            parameters: self.parameters,
            context_length: self.context_length,
            datasets: self.datasets,
            metrics: self.metrics,
            tags: self.tags,
            metadata: self.metadata,
        }
    }
}
|
||||
|
||||
/// Format parameter count as human-readable string.
///
/// Counts of at least 0.1B are rendered in billions with one decimal place
/// (e.g. `500_000_000` -> `"0.5B"`), matching the in-file unit tests and the
/// generated model-card markdown; smaller counts fall back to "M", "K", or
/// the raw number.
fn format_params(params: u64) -> String {
    const B: u64 = 1_000_000_000;
    const M: u64 = 1_000_000;
    const K: u64 = 1_000;

    // Sub-billion models such as 0.5B are conventionally quoted in billions,
    // so the "B" suffix kicks in from 100M upward (previously 500M rendered
    // as "500M", contradicting test_format_params / test_model_card_markdown).
    if params >= B / 10 {
        format!("{:.1}B", params as f64 / B as f64)
    } else if params >= M {
        format!("{:.0}M", params as f64 / M as f64)
    } else if params >= K {
        format!("{:.0}K", params as f64 / K as f64)
    } else {
        format!("{}", params)
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_card_builder() {
        let built = ModelCardBuilder::new("Test Model")
            .description("A test model")
            .architecture("llama")
            .parameters(500_000_000)
            .context_length(4096)
            .add_tag("test")
            .build();

        assert_eq!(built.name, "Test Model");
        assert_eq!(built.parameters, 500_000_000);
        assert_eq!(built.tags.len(), 1);
    }

    #[test]
    fn test_model_card_markdown() {
        let built = ModelCardBuilder::new("RuvLTRA Small")
            .description("Compact model")
            .parameters(500_000_000)
            .add_dataset("dataset1", Some("Training data".to_string()))
            .add_metric("perplexity", 5.2, Some("test-set".to_string()))
            .build();

        let rendered = built.to_markdown();
        assert!(rendered.contains("# RuvLTRA Small"));
        assert!(rendered.contains("0.5B"));
        assert!(rendered.contains("dataset1"));
        assert!(rendered.contains("perplexity"));
    }

    #[test]
    fn test_format_params() {
        assert_eq!(format_params(500), "500");
        assert_eq!(format_params(5_000), "5K");
        assert_eq!(format_params(5_000_000), "5M");
        assert_eq!(format_params(500_000_000), "0.5B");
        assert_eq!(format_params(3_000_000_000), "3.0B");
    }

    #[test]
    fn test_license_from_str() {
        use std::str::FromStr;

        assert_eq!(License::from_str("mit").unwrap(), License::Mit);
        assert_eq!(License::from_str("apache-2.0").unwrap(), License::Apache20);

        // Unknown identifiers fall through to the Other variant.
        match License::from_str("custom-license").unwrap() {
            License::Other(s) => assert_eq!(s, "custom-license"),
            _ => panic!("Expected Other variant"),
        }
    }
}
|
||||
298
vendor/ruvector/crates/ruvllm/src/hub/progress.rs
vendored
Normal file
298
vendor/ruvector/crates/ruvllm/src/hub/progress.rs
vendored
Normal file
@@ -0,0 +1,298 @@
|
||||
//! Progress tracking for download and upload operations
|
||||
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Rendering styles for the terminal progress indicator.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProgressStyle {
    /// Bare bar with percentage: `[=====>   ] 50%`
    Bar,
    /// Bar plus throughput and ETA: `[=====>   ] 50% (5.2 MB/s, ETA: 30s)`
    Detailed,
    /// Percentage only: `50% complete`
    Minimal,
}
|
||||
|
||||
/// Progress indicator for terminal output
|
||||
pub struct ProgressBar {
|
||||
/// Total bytes
|
||||
total: u64,
|
||||
/// Current bytes
|
||||
current: u64,
|
||||
/// Start time
|
||||
start_time: Instant,
|
||||
/// Last update time
|
||||
last_update: Instant,
|
||||
/// Progress style
|
||||
style: ProgressStyle,
|
||||
/// Bar width
|
||||
width: usize,
|
||||
/// Show in terminal
|
||||
enabled: bool,
|
||||
}
|
||||
|
||||
impl ProgressBar {
|
||||
/// Create a new progress bar
|
||||
pub fn new(total: u64) -> Self {
|
||||
Self {
|
||||
total,
|
||||
current: 0,
|
||||
start_time: Instant::now(),
|
||||
last_update: Instant::now(),
|
||||
style: ProgressStyle::Detailed,
|
||||
width: 40,
|
||||
enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set progress style
|
||||
pub fn with_style(mut self, style: ProgressStyle) -> Self {
|
||||
self.style = style;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set bar width
|
||||
pub fn with_width(mut self, width: usize) -> Self {
|
||||
self.width = width;
|
||||
self
|
||||
}
|
||||
|
||||
/// Enable or disable output
|
||||
pub fn enabled(mut self, enabled: bool) -> Self {
|
||||
self.enabled = enabled;
|
||||
self
|
||||
}
|
||||
|
||||
/// Update progress
|
||||
pub fn update(&mut self, current: u64) {
|
||||
self.current = current;
|
||||
self.last_update = Instant::now();
|
||||
|
||||
if self.enabled {
|
||||
self.render();
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment progress
|
||||
pub fn inc(&mut self, delta: u64) {
|
||||
self.update(self.current + delta);
|
||||
}
|
||||
|
||||
/// Finish progress bar
|
||||
pub fn finish(&mut self) {
|
||||
self.current = self.total;
|
||||
if self.enabled {
|
||||
self.render();
|
||||
println!(); // New line after completion
|
||||
}
|
||||
}
|
||||
|
||||
/// Render progress bar to terminal
|
||||
fn render(&self) {
|
||||
let percentage = if self.total == 0 {
|
||||
0.0
|
||||
} else {
|
||||
(self.current as f64 / self.total as f64) * 100.0
|
||||
};
|
||||
|
||||
match self.style {
|
||||
ProgressStyle::Bar => {
|
||||
let filled = ((percentage / 100.0) * self.width as f64) as usize;
|
||||
let bar = format!(
|
||||
"[{}>{}] {:.0}%",
|
||||
"=".repeat(filled),
|
||||
" ".repeat(self.width.saturating_sub(filled)),
|
||||
percentage
|
||||
);
|
||||
print!("\r{}", bar);
|
||||
}
|
||||
ProgressStyle::Detailed => {
|
||||
let filled = ((percentage / 100.0) * self.width as f64) as usize;
|
||||
let speed = self.calculate_speed();
|
||||
let eta = self.calculate_eta();
|
||||
|
||||
let bar = format!(
|
||||
"[{}>{}] {:.0}% ({}, ETA: {})",
|
||||
"=".repeat(filled),
|
||||
" ".repeat(self.width.saturating_sub(filled)),
|
||||
percentage,
|
||||
format_speed(speed),
|
||||
format_duration(eta)
|
||||
);
|
||||
print!("\r{}", bar);
|
||||
}
|
||||
ProgressStyle::Minimal => {
|
||||
print!("\r{:.0}% complete", percentage);
|
||||
}
|
||||
}
|
||||
|
||||
use std::io::{self, Write};
|
||||
let _ = io::stdout().flush();
|
||||
}
|
||||
|
||||
/// Calculate download/upload speed in bytes/sec
|
||||
fn calculate_speed(&self) -> f64 {
|
||||
let elapsed = self.start_time.elapsed().as_secs_f64();
|
||||
if elapsed > 0.0 {
|
||||
self.current as f64 / elapsed
|
||||
} else {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate estimated time remaining
|
||||
fn calculate_eta(&self) -> Duration {
|
||||
let remaining = self.total.saturating_sub(self.current);
|
||||
let speed = self.calculate_speed();
|
||||
|
||||
if speed > 0.0 {
|
||||
let seconds = remaining as f64 / speed;
|
||||
Duration::from_secs_f64(seconds)
|
||||
} else {
|
||||
Duration::from_secs(0)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Render a transfer rate in bytes/sec as a human-readable string
/// (binary units: B/s, KB/s, MB/s, GB/s).
fn format_speed(bps: f64) -> String {
    const KB: f64 = 1024.0;
    const MB: f64 = KB * 1024.0;
    const GB: f64 = MB * 1024.0;

    match bps {
        rate if rate >= GB => format!("{:.2} GB/s", rate / GB),
        rate if rate >= MB => format!("{:.2} MB/s", rate / MB),
        rate if rate >= KB => format!("{:.2} KB/s", rate / KB),
        rate => format!("{:.0} B/s", rate),
    }
}
|
||||
|
||||
/// Render a duration as a compact human-readable string:
/// seconds under a minute, "Xm Ys" under an hour, otherwise "Xh Ym".
fn format_duration(d: Duration) -> String {
    match d.as_secs() {
        s if s < 60 => format!("{}s", s),
        s if s < 3600 => format!("{}m {}s", s / 60, s % 60),
        s => format!("{}h {}m", s / 3600, (s % 3600) / 60),
    }
}
|
||||
|
||||
/// Callback invoked with `(current, total)` byte counts.
pub type ProgressCallback = Box<dyn Fn(u64, u64) + Send + Sync>;

/// Common interface for anything that can report operation progress.
pub trait ProgressIndicator {
    /// Record that `current` of `total` units are complete.
    fn update(&mut self, current: u64, total: u64);
    /// Mark the operation as finished.
    fn finish(&mut self);
}
|
||||
|
||||
impl ProgressIndicator for ProgressBar {
|
||||
fn update(&mut self, current: u64, _total: u64) {
|
||||
self.update(current);
|
||||
}
|
||||
|
||||
fn finish(&mut self) {
|
||||
self.finish();
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-progress manager for multiple concurrent operations
|
||||
pub struct MultiProgress {
|
||||
bars: Vec<ProgressBar>,
|
||||
}
|
||||
|
||||
impl MultiProgress {
|
||||
/// Create a new multi-progress manager
|
||||
pub fn new() -> Self {
|
||||
Self { bars: Vec::new() }
|
||||
}
|
||||
|
||||
/// Add a progress bar
|
||||
pub fn add(&mut self, bar: ProgressBar) -> usize {
|
||||
let id = self.bars.len();
|
||||
self.bars.push(bar);
|
||||
id
|
||||
}
|
||||
|
||||
/// Update a specific progress bar
|
||||
pub fn update(&mut self, id: usize, current: u64) {
|
||||
if let Some(bar) = self.bars.get_mut(id) {
|
||||
bar.update(current);
|
||||
}
|
||||
}
|
||||
|
||||
/// Finish all progress bars
|
||||
pub fn finish_all(&mut self) {
|
||||
for bar in &mut self.bars {
|
||||
bar.finish();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MultiProgress {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_progress_bar_creation() {
        let bar = ProgressBar::new(1000);
        assert_eq!(bar.total, 1000);
        assert_eq!(bar.current, 0);
    }

    #[test]
    fn test_progress_update() {
        let mut bar = ProgressBar::new(1000).enabled(false);
        bar.update(500);
        assert_eq!(bar.current, 500);
    }

    #[test]
    fn test_progress_increment() {
        let mut bar = ProgressBar::new(1000).enabled(false);
        bar.inc(100);
        bar.inc(100);
        assert_eq!(bar.current, 200);
    }

    #[test]
    fn test_format_speed() {
        assert_eq!(format_speed(500.0), "500 B/s");
        assert_eq!(format_speed(1024.0 * 10.0), "10.00 KB/s");
        assert_eq!(format_speed(1024.0 * 1024.0 * 5.0), "5.00 MB/s");
    }

    #[test]
    fn test_format_duration() {
        assert_eq!(format_duration(Duration::from_secs(30)), "30s");
        assert_eq!(format_duration(Duration::from_secs(90)), "1m 30s");
        assert_eq!(format_duration(Duration::from_secs(3700)), "1h 1m");
    }

    #[test]
    fn test_multi_progress() {
        let mut manager = MultiProgress::new();
        let first = manager.add(ProgressBar::new(100).enabled(false));
        let second = manager.add(ProgressBar::new(200).enabled(false));

        manager.update(first, 50);
        manager.update(second, 100);

        assert_eq!(manager.bars[first].current, 50);
        assert_eq!(manager.bars[second].current, 100);
    }
}
|
||||
443
vendor/ruvector/crates/ruvllm/src/hub/registry.rs
vendored
Normal file
443
vendor/ruvector/crates/ruvllm/src/hub/registry.rs
vendored
Normal file
@@ -0,0 +1,443 @@
|
||||
//! RuvLTRA model registry with pre-configured models
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Model size category
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum ModelSize {
|
||||
/// Tiny models (< 1B parameters)
|
||||
Tiny,
|
||||
/// Small models (0.5B - 1B parameters)
|
||||
Small,
|
||||
/// Medium models (1B - 5B parameters)
|
||||
Medium,
|
||||
/// Large models (5B - 10B parameters)
|
||||
Large,
|
||||
}
|
||||
|
||||
/// Quantization level
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum QuantizationLevel {
|
||||
/// 4-bit quantization (smallest, ~662MB for 0.5B model)
|
||||
Q4,
|
||||
/// 5-bit quantization (balanced)
|
||||
Q5,
|
||||
/// 8-bit quantization (highest quality)
|
||||
Q8,
|
||||
/// FP16 (no quantization)
|
||||
FP16,
|
||||
}
|
||||
|
||||
impl QuantizationLevel {
|
||||
/// Get file size multiplier relative to FP16
|
||||
pub fn size_multiplier(&self) -> f32 {
|
||||
match self {
|
||||
Self::Q4 => 0.25,
|
||||
Self::Q5 => 0.3125,
|
||||
Self::Q8 => 0.5,
|
||||
Self::FP16 => 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get expected memory reduction
|
||||
pub fn memory_reduction(&self) -> f32 {
|
||||
match self {
|
||||
Self::Q4 => 0.75, // 75% reduction
|
||||
Self::Q5 => 0.69, // 69% reduction
|
||||
Self::Q8 => 0.50, // 50% reduction
|
||||
Self::FP16 => 0.0, // No reduction
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Hardware requirements for model execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HardwareRequirements {
|
||||
/// Minimum RAM in GB
|
||||
pub min_ram_gb: f32,
|
||||
/// Recommended RAM in GB
|
||||
pub recommended_ram_gb: f32,
|
||||
/// Supports Apple Neural Engine
|
||||
pub supports_ane: bool,
|
||||
/// Supports Metal GPU acceleration
|
||||
pub supports_metal: bool,
|
||||
/// Supports CUDA
|
||||
pub supports_cuda: bool,
|
||||
/// Minimum GPU VRAM in GB (if using GPU)
|
||||
pub min_vram_gb: Option<f32>,
|
||||
}
|
||||
|
||||
/// Model information in the registry
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
/// Model identifier (e.g., "ruvltra-small")
|
||||
pub id: String,
|
||||
/// Display name
|
||||
pub name: String,
|
||||
/// HuggingFace repository (e.g., "ruvnet/ruvltra-small")
|
||||
pub repo: String,
|
||||
/// Model filename on HF Hub
|
||||
pub filename: String,
|
||||
/// Model size category
|
||||
pub size: ModelSize,
|
||||
/// Quantization level
|
||||
pub quantization: QuantizationLevel,
|
||||
/// File size in bytes
|
||||
pub size_bytes: u64,
|
||||
/// SHA256 checksum
|
||||
pub checksum: Option<String>,
|
||||
/// Number of parameters (in billions)
|
||||
pub params_b: f32,
|
||||
/// Context window size
|
||||
pub context_length: usize,
|
||||
/// Hardware requirements
|
||||
pub hardware: HardwareRequirements,
|
||||
/// Model description
|
||||
pub description: String,
|
||||
/// Whether this is a LoRA adapter
|
||||
pub is_adapter: bool,
|
||||
/// Base model required (for adapters)
|
||||
pub base_model: Option<String>,
|
||||
/// Includes SONA pre-trained weights
|
||||
pub has_sona_weights: bool,
|
||||
}
|
||||
|
||||
impl ModelInfo {
|
||||
/// Get download URL for this model
|
||||
pub fn download_url(&self) -> String {
|
||||
format!(
|
||||
"https://huggingface.co/{}/resolve/main/{}",
|
||||
self.repo, self.filename
|
||||
)
|
||||
}
|
||||
|
||||
/// Get HuggingFace Hub page URL
|
||||
pub fn hub_url(&self) -> String {
|
||||
format!("https://huggingface.co/{}", self.repo)
|
||||
}
|
||||
|
||||
/// Estimate download time in seconds at given speed (MB/s)
|
||||
pub fn estimate_download_time(&self, speed_mbps: f32) -> f32 {
|
||||
let size_mb = self.size_bytes as f32 / (1024.0 * 1024.0);
|
||||
size_mb / speed_mbps
|
||||
}
|
||||
|
||||
/// Check if model fits in available RAM
|
||||
pub fn fits_in_ram(&self, available_gb: f32) -> bool {
|
||||
available_gb >= self.hardware.min_ram_gb
|
||||
}
|
||||
}
|
||||
|
||||
/// RuvLTRA model registry
|
||||
pub struct RuvLtraRegistry {
|
||||
models: HashMap<String, ModelInfo>,
|
||||
}
|
||||
|
||||
impl RuvLtraRegistry {
|
||||
/// Create a new registry with pre-configured models
|
||||
pub fn new() -> Self {
|
||||
let mut models = HashMap::new();
|
||||
|
||||
// RuvLTRA-Small (0.5B) - Q4 quantization
|
||||
models.insert(
|
||||
"ruvltra-small".to_string(),
|
||||
ModelInfo {
|
||||
id: "ruvltra-small".to_string(),
|
||||
name: "RuvLTRA Small (0.5B Q4)".to_string(),
|
||||
repo: "ruv/ruvltra".to_string(),
|
||||
filename: "ruvltra-small-0.5b-q4_k_m.gguf".to_string(),
|
||||
size: ModelSize::Small,
|
||||
quantization: QuantizationLevel::Q4,
|
||||
size_bytes: 662_000_000, // ~662MB
|
||||
checksum: None, // Set after publishing
|
||||
params_b: 0.5,
|
||||
context_length: 4096,
|
||||
hardware: HardwareRequirements {
|
||||
min_ram_gb: 1.0,
|
||||
recommended_ram_gb: 2.0,
|
||||
supports_ane: true,
|
||||
supports_metal: true,
|
||||
supports_cuda: true,
|
||||
min_vram_gb: Some(1.0),
|
||||
},
|
||||
description: "Compact RuvLTRA model optimized for edge devices. \
|
||||
Includes SONA pre-trained weights for adaptive learning."
|
||||
.to_string(),
|
||||
is_adapter: false,
|
||||
base_model: None,
|
||||
has_sona_weights: true,
|
||||
},
|
||||
);
|
||||
|
||||
// RuvLTRA-Small (0.5B) - Q8 quantization
|
||||
models.insert(
|
||||
"ruvltra-small-q8".to_string(),
|
||||
ModelInfo {
|
||||
id: "ruvltra-small-q8".to_string(),
|
||||
name: "RuvLTRA Small (0.5B Q8)".to_string(),
|
||||
repo: "ruv/ruvltra".to_string(),
|
||||
filename: "ruvltra-small-0.5b-q8_0.gguf".to_string(),
|
||||
size: ModelSize::Small,
|
||||
quantization: QuantizationLevel::Q8,
|
||||
size_bytes: 1_324_000_000, // ~1.3GB
|
||||
checksum: None,
|
||||
params_b: 0.5,
|
||||
context_length: 4096,
|
||||
hardware: HardwareRequirements {
|
||||
min_ram_gb: 2.0,
|
||||
recommended_ram_gb: 4.0,
|
||||
supports_ane: true,
|
||||
supports_metal: true,
|
||||
supports_cuda: true,
|
||||
min_vram_gb: Some(2.0),
|
||||
},
|
||||
description: "High-quality Q8 quantization for better accuracy.".to_string(),
|
||||
is_adapter: false,
|
||||
base_model: None,
|
||||
has_sona_weights: true,
|
||||
},
|
||||
);
|
||||
|
||||
// RuvLTRA-Medium (3B) - Q4 quantization
|
||||
models.insert(
|
||||
"ruvltra-medium".to_string(),
|
||||
ModelInfo {
|
||||
id: "ruvltra-medium".to_string(),
|
||||
name: "RuvLTRA Medium (3B Q4)".to_string(),
|
||||
repo: "ruv/ruvltra".to_string(),
|
||||
filename: "ruvltra-medium-1.1b-q4_k_m.gguf".to_string(),
|
||||
size: ModelSize::Medium,
|
||||
quantization: QuantizationLevel::Q4,
|
||||
size_bytes: 2_100_000_000, // ~2.1GB
|
||||
checksum: None,
|
||||
params_b: 3.0,
|
||||
context_length: 8192,
|
||||
hardware: HardwareRequirements {
|
||||
min_ram_gb: 4.0,
|
||||
recommended_ram_gb: 8.0,
|
||||
supports_ane: true,
|
||||
supports_metal: true,
|
||||
supports_cuda: true,
|
||||
min_vram_gb: Some(4.0),
|
||||
},
|
||||
description: "Balanced RuvLTRA model for general-purpose tasks. \
|
||||
Extended context window with SONA learning."
|
||||
.to_string(),
|
||||
is_adapter: false,
|
||||
base_model: None,
|
||||
has_sona_weights: true,
|
||||
},
|
||||
);
|
||||
|
||||
// RuvLTRA-Medium (3B) - Q8 quantization
|
||||
models.insert(
|
||||
"ruvltra-medium-q8".to_string(),
|
||||
ModelInfo {
|
||||
id: "ruvltra-medium-q8".to_string(),
|
||||
name: "RuvLTRA Medium (3B Q8)".to_string(),
|
||||
repo: "ruv/ruvltra".to_string(),
|
||||
filename: "ruvltra-medium-1.1b-q8_0.gguf".to_string(),
|
||||
size: ModelSize::Medium,
|
||||
quantization: QuantizationLevel::Q8,
|
||||
size_bytes: 4_200_000_000, // ~4.2GB
|
||||
checksum: None,
|
||||
params_b: 3.0,
|
||||
context_length: 8192,
|
||||
hardware: HardwareRequirements {
|
||||
min_ram_gb: 6.0,
|
||||
recommended_ram_gb: 12.0,
|
||||
supports_ane: true,
|
||||
supports_metal: true,
|
||||
supports_cuda: true,
|
||||
min_vram_gb: Some(6.0),
|
||||
},
|
||||
description: "High-quality Medium model with Q8 quantization.".to_string(),
|
||||
is_adapter: false,
|
||||
base_model: None,
|
||||
has_sona_weights: true,
|
||||
},
|
||||
);
|
||||
|
||||
// RuvLTRA-Small-Coder (LoRA adapter)
|
||||
models.insert(
|
||||
"ruvltra-small-coder".to_string(),
|
||||
ModelInfo {
|
||||
id: "ruvltra-small-coder".to_string(),
|
||||
name: "RuvLTRA Small Coder (LoRA)".to_string(),
|
||||
repo: "ruv/ruvltra".to_string(),
|
||||
filename: "ruvltra-small-coder-lora.safetensors".to_string(),
|
||||
size: ModelSize::Tiny,
|
||||
quantization: QuantizationLevel::FP16,
|
||||
size_bytes: 50_000_000, // ~50MB (LoRA is small)
|
||||
checksum: None,
|
||||
params_b: 0.05, // Adapter parameters
|
||||
context_length: 4096,
|
||||
hardware: HardwareRequirements {
|
||||
min_ram_gb: 0.1,
|
||||
recommended_ram_gb: 0.5,
|
||||
supports_ane: true,
|
||||
supports_metal: true,
|
||||
supports_cuda: true,
|
||||
min_vram_gb: None,
|
||||
},
|
||||
description: "LoRA adapter for code completion. \
|
||||
Requires ruvltra-small or ruvltra-small-q8 base model."
|
||||
.to_string(),
|
||||
is_adapter: true,
|
||||
base_model: Some("ruvltra-small".to_string()),
|
||||
has_sona_weights: false,
|
||||
},
|
||||
);
|
||||
|
||||
Self { models }
|
||||
}
|
||||
|
||||
/// Get model info by ID
|
||||
pub fn get(&self, id: &str) -> Option<&ModelInfo> {
|
||||
self.models.get(id)
|
||||
}
|
||||
|
||||
/// Get all available models
|
||||
pub fn list_all(&self) -> Vec<&ModelInfo> {
|
||||
self.models.values().collect()
|
||||
}
|
||||
|
||||
/// Get models by size
|
||||
pub fn list_by_size(&self, size: ModelSize) -> Vec<&ModelInfo> {
|
||||
self.models.values().filter(|m| m.size == size).collect()
|
||||
}
|
||||
|
||||
/// Get base models (exclude adapters)
|
||||
pub fn list_base_models(&self) -> Vec<&ModelInfo> {
|
||||
self.models.values().filter(|m| !m.is_adapter).collect()
|
||||
}
|
||||
|
||||
/// Get adapters for a specific base model
|
||||
pub fn list_adapters(&self, base_model: &str) -> Vec<&ModelInfo> {
|
||||
self.models
|
||||
.values()
|
||||
.filter(|m| {
|
||||
m.is_adapter
|
||||
&& m.base_model
|
||||
.as_ref()
|
||||
.map(|b| b == base_model)
|
||||
.unwrap_or(false)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Recommend model based on available RAM
|
||||
pub fn recommend_for_ram(&self, available_gb: f32) -> Option<&ModelInfo> {
|
||||
let mut candidates: Vec<_> = self
|
||||
.models
|
||||
.values()
|
||||
.filter(|m| !m.is_adapter && m.fits_in_ram(available_gb))
|
||||
.collect();
|
||||
|
||||
// Sort by parameters (largest that fits)
|
||||
candidates.sort_by(|a, b| b.params_b.partial_cmp(&a.params_b).unwrap());
|
||||
|
||||
candidates.first().copied()
|
||||
}
|
||||
|
||||
/// Get model IDs
|
||||
pub fn model_ids(&self) -> Vec<String> {
|
||||
self.models.keys().cloned().collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RuvLtraRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get model info by ID (convenience function)
|
||||
pub fn get_model_info(id: &str) -> Option<ModelInfo> {
|
||||
RuvLtraRegistry::new().get(id).cloned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_registry_initialization() {
        let registry = RuvLtraRegistry::new();
        assert!(registry.get("ruvltra-small").is_some());
        assert!(registry.get("ruvltra-medium").is_some());
        assert!(registry.get("nonexistent").is_none());
    }

    #[test]
    fn test_model_info() {
        let registry = RuvLtraRegistry::new();
        let small = registry.get("ruvltra-small").unwrap();

        assert_eq!(small.params_b, 0.5);
        assert_eq!(small.quantization, QuantizationLevel::Q4);
        assert!(small.has_sona_weights);
        assert!(!small.is_adapter);
    }

    #[test]
    fn test_list_by_size() {
        let registry = RuvLtraRegistry::new();
        assert!(!registry.list_by_size(ModelSize::Small).is_empty());
    }

    #[test]
    fn test_adapters() {
        let registry = RuvLtraRegistry::new();
        let adapters = registry.list_adapters("ruvltra-small");
        assert!(!adapters.is_empty());
        assert!(adapters[0].is_adapter);
    }

    #[test]
    fn test_ram_recommendation() {
        let registry = RuvLtraRegistry::new();

        // 2GB should yield a small model.
        let pick = registry.recommend_for_ram(2.0);
        assert!(pick.is_some());
        assert!(pick.unwrap().params_b <= 1.0);

        // 8GB should yield a recommendation as well (medium fits).
        let pick = registry.recommend_for_ram(8.0);
        assert!(pick.is_some());
    }

    #[test]
    fn test_quantization_multipliers() {
        assert_eq!(QuantizationLevel::Q4.size_multiplier(), 0.25);
        assert_eq!(QuantizationLevel::Q8.size_multiplier(), 0.5);
        assert_eq!(QuantizationLevel::FP16.size_multiplier(), 1.0);
    }

    #[test]
    fn test_model_urls() {
        let registry = RuvLtraRegistry::new();
        let small = registry.get("ruvltra-small").unwrap();

        let url = small.download_url();
        assert!(url.contains("huggingface.co"));
        assert!(url.contains("ruv/ruvltra"));
        assert!(url.contains(".gguf"));

        assert_eq!(small.hub_url(), "https://huggingface.co/ruv/ruvltra");
    }

    #[test]
    fn test_download_time_estimation() {
        let registry = RuvLtraRegistry::new();
        let small = registry.get("ruvltra-small").unwrap();

        // ~662MB at 10 MB/s should take roughly 66 seconds.
        let secs = small.estimate_download_time(10.0);
        assert!(secs > 60.0 && secs < 70.0);
    }
}
|
||||
440
vendor/ruvector/crates/ruvllm/src/hub/upload.rs
vendored
Normal file
440
vendor/ruvector/crates/ruvllm/src/hub/upload.rs
vendored
Normal file
@@ -0,0 +1,440 @@
|
||||
//! Model upload functionality for publishing to HuggingFace Hub
|
||||
|
||||
use super::model_card::{ModelCard, ModelCardBuilder};
|
||||
use super::{get_hf_token, HubError, Result};
|
||||
use regex::Regex;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
// ============================================================================
|
||||
// Security: Input Validation (H-002)
|
||||
// ============================================================================
|
||||
|
||||
/// Validate repo_id format (prevents CLI injection)
|
||||
/// Only allows: alphanumeric, /, -, _, .
|
||||
fn validate_repo_id(repo_id: &str) -> Result<()> {
|
||||
// Must contain exactly one slash (user/repo format)
|
||||
let slash_count = repo_id.chars().filter(|&c| c == '/').count();
|
||||
if slash_count != 1 {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"Repository ID must be in format 'username/repo-name'".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Regex: only allow safe characters
|
||||
let valid_pattern = Regex::new(r"^[a-zA-Z0-9][a-zA-Z0-9._-]*/[a-zA-Z0-9][a-zA-Z0-9._-]*$")
|
||||
.expect("Invalid regex pattern");
|
||||
|
||||
if !valid_pattern.is_match(repo_id) {
|
||||
return Err(HubError::InvalidFormat(format!(
|
||||
"Repository ID '{}' contains invalid characters. Only alphanumeric, /, -, _, . are allowed",
|
||||
repo_id
|
||||
)));
|
||||
}
|
||||
|
||||
// Prevent path traversal
|
||||
if repo_id.contains("..") {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"Repository ID cannot contain '..' (path traversal)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Prevent shell metacharacters that could be used for injection
|
||||
let dangerous_chars = [
|
||||
'`', '$', '(', ')', ';', '&', '|', '<', '>', '\n', '\r', '"', '\'', '\\',
|
||||
];
|
||||
for c in dangerous_chars {
|
||||
if repo_id.contains(c) {
|
||||
return Err(HubError::InvalidFormat(format!(
|
||||
"Repository ID cannot contain shell metacharacter '{}'",
|
||||
c
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Validate file path for upload (prevents path traversal)
|
||||
fn validate_upload_path(path: &Path) -> Result<()> {
|
||||
let path_str = path.to_string_lossy();
|
||||
|
||||
// Prevent path traversal
|
||||
if path_str.contains("..") {
|
||||
return Err(HubError::InvalidFormat(
|
||||
"File path cannot contain '..' (path traversal)".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Canonicalize to resolve any symlinks and verify it exists
|
||||
let canonical = path.canonicalize().map_err(|e| {
|
||||
HubError::NotFound(format!("Cannot resolve path '{}': {}", path.display(), e))
|
||||
})?;
|
||||
|
||||
// Verify the file exists and is a regular file
|
||||
if !canonical.is_file() {
|
||||
return Err(HubError::NotFound(format!(
|
||||
"Path '{}' is not a regular file",
|
||||
path.display()
|
||||
)));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Configuration for publishing a model to HuggingFace Hub.
#[derive(Debug, Clone)]
pub struct UploadConfig {
    /// HuggingFace access token used for authentication (required).
    pub hf_token: String,
    /// Create the repository as private.
    pub private: bool,
    /// Create the repository when it does not already exist.
    pub create_repo: bool,
    /// Also upload SONA weights as a separate artifact.
    pub include_sona_weights: bool,
    /// Auto-generate a model card alongside the upload.
    pub auto_model_card: bool,
    /// Commit message used for the upload.
    pub commit_message: String,
}

impl UploadConfig {
    /// Build a config with the given token and permissive defaults
    /// (public repo, create-if-missing, SONA weights and model card on).
    pub fn new(hf_token: String) -> Self {
        Self {
            hf_token,
            private: false,
            create_repo: true,
            include_sona_weights: true,
            auto_model_card: true,
            commit_message: "Upload RuvLTRA model".to_string(),
        }
    }

    /// Toggle repository visibility.
    pub fn private(mut self, private: bool) -> Self {
        self.private = private;
        self
    }

    /// Override the commit message.
    pub fn commit_message(mut self, message: impl Into<String>) -> Self {
        self.commit_message = message.into();
        self
    }
}
|
||||
|
||||
/// Snapshot of an in-flight upload's progress.
#[derive(Debug, Clone)]
pub struct UploadProgress {
    /// Total bytes that will be uploaded.
    pub total_bytes: u64,
    /// Bytes sent so far.
    pub uploaded_bytes: u64,
    /// Current transfer rate in bytes/sec.
    pub speed_bps: f64,
    /// Name of the file currently being sent.
    pub current_file: String,
    /// Current phase of the upload.
    pub stage: UploadStage,
}

/// Phases an upload moves through, in order; terminal states are
/// `Complete` and `Failed`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum UploadStage {
    /// Gathering files and metadata.
    Preparing,
    /// Creating the target repository.
    CreatingRepo,
    /// Sending the model artifact.
    UploadingModel,
    /// Sending the SONA weights artifact.
    UploadingSona,
    /// Sending the model card.
    UploadingCard,
    /// Upload finished successfully.
    Complete,
    /// Upload failed, with a reason.
    Failed(String),
}
|
||||
|
||||
/// Descriptive metadata attached to a model at upload time.
#[derive(Debug, Clone)]
pub struct ModelMetadata {
    /// Display name.
    pub name: String,
    /// Optional free-text description.
    pub description: Option<String>,
    /// Architecture name (e.g. "llama").
    pub architecture: String,
    /// Parameter count in billions.
    pub params_b: f32,
    /// Context window size in tokens.
    pub context_length: usize,
    /// Quantization type, when applicable.
    pub quantization: Option<String>,
    /// License identifier, when known.
    pub license: Option<String>,
    /// Training dataset names.
    pub datasets: Vec<String>,
    /// Discovery tags.
    pub tags: Vec<String>,
}
|
||||
|
||||
/// Model uploader
|
||||
pub struct ModelUploader {
|
||||
config: UploadConfig,
|
||||
}
|
||||
|
||||
impl ModelUploader {
|
||||
/// Create a new uploader with HF token
|
||||
pub fn new(hf_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
config: UploadConfig::new(hf_token.into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create uploader with custom config
|
||||
pub fn with_config(config: UploadConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Upload a model file to HuggingFace Hub
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model_path` - Path to the model file (.gguf)
|
||||
/// * `repo_id` - HuggingFace repository (e.g., "username/model-name")
|
||||
/// * `metadata` - Optional model metadata
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// let uploader = ModelUploader::new("hf_token");
|
||||
/// uploader.upload(
|
||||
/// "./ruvltra-custom.gguf",
|
||||
/// "username/ruvltra-custom",
|
||||
/// Some(metadata),
|
||||
/// )?;
|
||||
/// ```
|
||||
pub fn upload(
|
||||
&self,
|
||||
model_path: impl AsRef<Path>,
|
||||
repo_id: &str,
|
||||
metadata: Option<ModelMetadata>,
|
||||
) -> Result<String> {
|
||||
let model_path = model_path.as_ref();
|
||||
|
||||
// SECURITY: Validate repository ID format (prevents CLI injection)
|
||||
validate_repo_id(repo_id)?;
|
||||
|
||||
// SECURITY: Validate and canonicalize file path (prevents path traversal)
|
||||
validate_upload_path(model_path)?;
|
||||
|
||||
// For now, use git-based upload via huggingface-cli
|
||||
// In production, this would use the HF API
|
||||
self.upload_via_cli(model_path, repo_id, metadata)
|
||||
}
|
||||
|
||||
/// Upload using huggingface-cli (requires huggingface-cli to be installed)
|
||||
fn upload_via_cli(
|
||||
&self,
|
||||
model_path: &Path,
|
||||
repo_id: &str,
|
||||
metadata: Option<ModelMetadata>,
|
||||
) -> Result<String> {
|
||||
// Check if huggingface-cli is available
|
||||
if !self.has_hf_cli() {
|
||||
return Err(HubError::Config(
|
||||
"huggingface-cli not found. Install with: pip install huggingface_hub[cli]"
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Create repository if needed
|
||||
if self.config.create_repo {
|
||||
self.create_repo_cli(repo_id)?;
|
||||
}
|
||||
|
||||
// Upload model file
|
||||
self.upload_file_cli(model_path, repo_id)?;
|
||||
|
||||
// Generate and upload model card if enabled
|
||||
if self.config.auto_model_card {
|
||||
if let Some(meta) = metadata {
|
||||
let card = self.generate_model_card(&meta);
|
||||
self.upload_model_card_cli(&card, repo_id)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(format!("https://huggingface.co/{}", repo_id))
|
||||
}
|
||||
|
||||
/// Check if huggingface-cli is available
|
||||
fn has_hf_cli(&self) -> bool {
|
||||
std::process::Command::new("huggingface-cli")
|
||||
.arg("--version")
|
||||
.output()
|
||||
.map(|o| o.status.success())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Create repository using huggingface-cli
|
||||
fn create_repo_cli(&self, repo_id: &str) -> Result<()> {
|
||||
let mut args = vec![
|
||||
"repo".to_string(),
|
||||
"create".to_string(),
|
||||
repo_id.to_string(),
|
||||
];
|
||||
|
||||
if self.config.private {
|
||||
args.push("--private".to_string());
|
||||
}
|
||||
|
||||
let status = std::process::Command::new("huggingface-cli")
|
||||
.args(&args)
|
||||
.env("HF_TOKEN", &self.config.hf_token)
|
||||
.status()
|
||||
.map_err(|e| HubError::Network(e.to_string()))?;
|
||||
|
||||
if !status.success() && status.code() != Some(1) {
|
||||
// Exit code 1 might mean repo already exists
|
||||
return Err(HubError::Network("Failed to create repository".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Upload file using huggingface-cli
|
||||
fn upload_file_cli(&self, file_path: &Path, repo_id: &str) -> Result<()> {
|
||||
let args = vec![
|
||||
"upload".to_string(),
|
||||
repo_id.to_string(),
|
||||
file_path.to_str().unwrap().to_string(),
|
||||
"--commit-message".to_string(),
|
||||
self.config.commit_message.clone(),
|
||||
];
|
||||
|
||||
let status = std::process::Command::new("huggingface-cli")
|
||||
.args(&args)
|
||||
.env("HF_TOKEN", &self.config.hf_token)
|
||||
.status()
|
||||
.map_err(|e| HubError::Network(e.to_string()))?;
|
||||
|
||||
if !status.success() {
|
||||
return Err(HubError::Network("Failed to upload file".to_string()));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Generate model card from metadata
|
||||
fn generate_model_card(&self, metadata: &ModelMetadata) -> ModelCard {
|
||||
use super::model_card::{Framework, License, TaskType};
|
||||
|
||||
let mut builder = ModelCardBuilder::new(&metadata.name);
|
||||
|
||||
if let Some(desc) = &metadata.description {
|
||||
builder = builder.description(desc);
|
||||
}
|
||||
|
||||
builder = builder
|
||||
.task(TaskType::TextGeneration)
|
||||
.framework(Framework::Gguf)
|
||||
.architecture(&metadata.architecture)
|
||||
.parameters((metadata.params_b * 1e9) as u64)
|
||||
.context_length(metadata.context_length);
|
||||
|
||||
if let Some(quant) = &metadata.quantization {
|
||||
builder = builder.add_tag(quant);
|
||||
}
|
||||
|
||||
if let Some(license) = &metadata.license {
|
||||
if let Ok(lic) = license.parse() {
|
||||
builder = builder.license(lic);
|
||||
}
|
||||
}
|
||||
|
||||
for dataset in &metadata.datasets {
|
||||
builder = builder.add_dataset(dataset, None);
|
||||
}
|
||||
|
||||
for tag in &metadata.tags {
|
||||
builder = builder.add_tag(tag);
|
||||
}
|
||||
|
||||
builder.build()
|
||||
}
|
||||
|
||||
/// Upload model card
|
||||
fn upload_model_card_cli(&self, card: &ModelCard, repo_id: &str) -> Result<()> {
|
||||
// Write card to temporary file
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let card_path = temp_dir.join("README.md");
|
||||
fs::write(&card_path, card.to_markdown())?;
|
||||
|
||||
// Upload README.md
|
||||
self.upload_file_cli(&card_path, repo_id)?;
|
||||
|
||||
// Clean up
|
||||
let _ = fs::remove_file(&card_path);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Upload error type
///
/// NOTE(review): the uploader methods in this file return the module-wide
/// `Result`/`HubError` rather than this type — confirm whether `UploadError`
/// is consumed elsewhere or is a candidate for removal.
#[derive(Debug, thiserror::Error)]
pub enum UploadError {
    /// Authentication error (bad or missing HF token)
    #[error("Authentication failed: {0}")]
    Auth(String),
    /// Network error (transport or CLI failure)
    #[error("Network error: {0}")]
    Network(String),
    /// IO error (auto-converted from `std::io::Error` via `#[from]`)
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults: public repo, auto repo creation, SONA weights included.
    #[test]
    fn test_upload_config() {
        let cfg = UploadConfig::new("test_token".to_string());
        assert!(cfg.create_repo);
        assert!(cfg.include_sona_weights);
        assert!(!cfg.private);
    }

    /// Builder setters override the defaults.
    #[test]
    fn test_upload_config_builder() {
        let cfg = UploadConfig::new(String::from("token"))
            .commit_message("Custom message")
            .private(true);

        assert_eq!(cfg.commit_message, "Custom message");
        assert!(cfg.private);
    }

    /// ModelMetadata is a plain data holder; fields round-trip as assigned.
    #[test]
    fn test_model_metadata() {
        let meta = ModelMetadata {
            name: String::from("RuvLTRA Test"),
            architecture: String::from("llama"),
            description: Some(String::from("Test model")),
            params_b: 0.5,
            context_length: 4096,
            quantization: Some(String::from("Q4_K_M")),
            license: Some(String::from("MIT")),
            datasets: vec![String::from("dataset1")],
            tags: vec![String::from("test")],
        };

        assert!(meta.description.is_some());
        assert_eq!(meta.params_b, 0.5);
    }
}
|
||||
Reference in New Issue
Block a user