212 lines
6.1 KiB
Rust
212 lines
6.1 KiB
Rust
//! Model download command implementation
|
|
//!
|
|
//! Downloads models from HuggingFace Hub with progress indication,
|
|
//! supporting various quantization formats optimized for Apple Silicon.
|
|
|
|
use anyhow::{Context, Result};
|
|
use bytesize::ByteSize;
|
|
use colored::Colorize;
|
|
use console::style;
|
|
use hf_hub::api::tokio::Api;
|
|
use hf_hub::{Repo, RepoType};
|
|
use indicatif::{ProgressBar, ProgressStyle};
|
|
use std::path::{Path, PathBuf};
|
|
|
|
use crate::models::{get_model, resolve_model_id, QuantPreset};
|
|
|
|
/// Run the download command
|
|
pub async fn run(
|
|
model: &str,
|
|
quantization: &str,
|
|
force: bool,
|
|
revision: Option<&str>,
|
|
cache_dir: &str,
|
|
) -> Result<()> {
|
|
let model_id = resolve_model_id(model);
|
|
let quant = QuantPreset::from_str(quantization)
|
|
.ok_or_else(|| anyhow::anyhow!("Invalid quantization format: {}", quantization))?;
|
|
|
|
println!();
|
|
println!(
|
|
"{} {} ({})",
|
|
style("Downloading:").bold().cyan(),
|
|
model_id,
|
|
quant
|
|
);
|
|
println!();
|
|
|
|
// Get model info if available
|
|
if let Some(model_def) = get_model(model) {
|
|
println!(" {} {}", "Name:".dimmed(), model_def.name);
|
|
println!(" {} {}", "Architecture:".dimmed(), model_def.architecture);
|
|
println!(" {} {}B", "Parameters:".dimmed(), model_def.params_b);
|
|
println!(
|
|
" {} ~{:.1} GB",
|
|
"Est. Memory:".dimmed(),
|
|
quant.estimate_memory_gb(model_def.params_b)
|
|
);
|
|
println!();
|
|
}
|
|
|
|
// Initialize HuggingFace API
|
|
let api = Api::new().context("Failed to initialize HuggingFace API")?;
|
|
|
|
// Create repo reference
|
|
let repo = if let Some(rev) = revision {
|
|
api.repo(Repo::with_revision(
|
|
model_id.clone(),
|
|
RepoType::Model,
|
|
rev.to_string(),
|
|
))
|
|
} else {
|
|
api.repo(Repo::new(model_id.clone(), RepoType::Model))
|
|
};
|
|
|
|
// Determine files to download
|
|
let files_to_download = get_files_to_download(&model_id, quant);
|
|
|
|
// Create cache directory
|
|
let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);
|
|
tokio::fs::create_dir_all(&model_cache_dir)
|
|
.await
|
|
.context("Failed to create cache directory")?;
|
|
|
|
// Download each file
|
|
for file_name in &files_to_download {
|
|
let target_path = model_cache_dir.join(file_name);
|
|
|
|
// Check if file exists
|
|
if target_path.exists() && !force {
|
|
let size = tokio::fs::metadata(&target_path).await?.len();
|
|
println!(
|
|
" {} {} ({})",
|
|
style("Cached:").green(),
|
|
file_name,
|
|
ByteSize(size)
|
|
);
|
|
continue;
|
|
}
|
|
|
|
println!(" {} {}", style("Downloading:").yellow(), file_name);
|
|
|
|
// Download with progress
|
|
let downloaded_path = download_with_progress(&repo, file_name).await?;
|
|
|
|
// Copy to cache directory
|
|
tokio::fs::copy(&downloaded_path, &target_path)
|
|
.await
|
|
.context("Failed to copy file to cache")?;
|
|
|
|
let size = tokio::fs::metadata(&target_path).await?.len();
|
|
println!(
|
|
" {} {} ({})",
|
|
style("Downloaded:").green(),
|
|
file_name,
|
|
ByteSize(size)
|
|
);
|
|
}
|
|
|
|
println!();
|
|
println!(
|
|
"{} Model ready at: {}",
|
|
style("Success!").green().bold(),
|
|
model_cache_dir.display()
|
|
);
|
|
println!();
|
|
|
|
// Print usage hint
|
|
println!("{}", "Quick start:".bold());
|
|
println!(" ruvllm chat {}", model);
|
|
println!(" ruvllm serve {}", model);
|
|
println!();
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Download a file with progress indication
|
|
async fn download_with_progress(
|
|
repo: &hf_hub::api::tokio::ApiRepo,
|
|
file_name: &str,
|
|
) -> Result<PathBuf> {
|
|
// Create progress bar
|
|
let pb = ProgressBar::new(100);
|
|
pb.set_style(
|
|
ProgressStyle::default_bar()
|
|
.template(" [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
|
|
.unwrap()
|
|
.progress_chars("#>-"),
|
|
);
|
|
|
|
// Download file
|
|
let path = repo
|
|
.get(file_name)
|
|
.await
|
|
.context(format!("Failed to download {}", file_name))?;
|
|
|
|
pb.finish_and_clear();
|
|
|
|
Ok(path)
|
|
}
|
|
|
|
/// Get list of files to download for a model and quantization
|
|
fn get_files_to_download(model_id: &str, quant: QuantPreset) -> Vec<String> {
|
|
let mut files = vec![
|
|
"tokenizer.json".to_string(),
|
|
"tokenizer_config.json".to_string(),
|
|
"config.json".to_string(),
|
|
];
|
|
|
|
// Add model weights based on quantization
|
|
if model_id.contains("GGUF") || quant != QuantPreset::None {
|
|
// Look for GGUF files
|
|
files.push(format!("*{}", quant.gguf_suffix()));
|
|
} else {
|
|
// SafeTensors format
|
|
files.push("model.safetensors".to_string());
|
|
}
|
|
|
|
// Add special tokens and chat template if available
|
|
files.push("special_tokens_map.json".to_string());
|
|
files.push("generation_config.json".to_string());
|
|
|
|
files
|
|
}
|
|
|
|
/// Check if a model is already downloaded
|
|
pub async fn is_model_downloaded(model: &str, cache_dir: &str) -> bool {
|
|
let model_id = resolve_model_id(model);
|
|
let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);
|
|
|
|
// Check for tokenizer and at least one model file
|
|
let tokenizer_exists = model_cache_dir.join("tokenizer.json").exists();
|
|
let has_weights = tokio::fs::read_dir(&model_cache_dir)
|
|
.await
|
|
.ok()
|
|
.map(|mut dir| {
|
|
use futures::StreamExt;
|
|
// Simplified check - just see if directory exists and has files
|
|
true
|
|
})
|
|
.unwrap_or(false);
|
|
|
|
tokenizer_exists && has_weights
|
|
}
|
|
|
|
/// Get the path to a downloaded model
|
|
pub fn get_model_path(model: &str, cache_dir: &str) -> PathBuf {
|
|
let model_id = resolve_model_id(model);
|
|
PathBuf::from(cache_dir).join("models").join(&model_id)
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_files_to_download() {
|
|
let files = get_files_to_download("test/model", QuantPreset::Q4K);
|
|
assert!(files.contains(&"tokenizer.json".to_string()));
|
|
assert!(files.iter().any(|f| f.contains("Q4_K_M")));
|
|
}
|
|
}
|