Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
211
vendor/ruvector/crates/ruvllm-cli/src/commands/download.rs
vendored
Normal file
211
vendor/ruvector/crates/ruvllm-cli/src/commands/download.rs
vendored
Normal file
@@ -0,0 +1,211 @@
|
||||
//! Model download command implementation
|
||||
//!
|
||||
//! Downloads models from HuggingFace Hub with progress indication,
|
||||
//! supporting various quantization formats optimized for Apple Silicon.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use bytesize::ByteSize;
|
||||
use colored::Colorize;
|
||||
use console::style;
|
||||
use hf_hub::api::tokio::Api;
|
||||
use hf_hub::{Repo, RepoType};
|
||||
use indicatif::{ProgressBar, ProgressStyle};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use crate::models::{get_model, resolve_model_id, QuantPreset};
|
||||
|
||||
/// Run the download command
|
||||
pub async fn run(
|
||||
model: &str,
|
||||
quantization: &str,
|
||||
force: bool,
|
||||
revision: Option<&str>,
|
||||
cache_dir: &str,
|
||||
) -> Result<()> {
|
||||
let model_id = resolve_model_id(model);
|
||||
let quant = QuantPreset::from_str(quantization)
|
||||
.ok_or_else(|| anyhow::anyhow!("Invalid quantization format: {}", quantization))?;
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"{} {} ({})",
|
||||
style("Downloading:").bold().cyan(),
|
||||
model_id,
|
||||
quant
|
||||
);
|
||||
println!();
|
||||
|
||||
// Get model info if available
|
||||
if let Some(model_def) = get_model(model) {
|
||||
println!(" {} {}", "Name:".dimmed(), model_def.name);
|
||||
println!(" {} {}", "Architecture:".dimmed(), model_def.architecture);
|
||||
println!(" {} {}B", "Parameters:".dimmed(), model_def.params_b);
|
||||
println!(
|
||||
" {} ~{:.1} GB",
|
||||
"Est. Memory:".dimmed(),
|
||||
quant.estimate_memory_gb(model_def.params_b)
|
||||
);
|
||||
println!();
|
||||
}
|
||||
|
||||
// Initialize HuggingFace API
|
||||
let api = Api::new().context("Failed to initialize HuggingFace API")?;
|
||||
|
||||
// Create repo reference
|
||||
let repo = if let Some(rev) = revision {
|
||||
api.repo(Repo::with_revision(
|
||||
model_id.clone(),
|
||||
RepoType::Model,
|
||||
rev.to_string(),
|
||||
))
|
||||
} else {
|
||||
api.repo(Repo::new(model_id.clone(), RepoType::Model))
|
||||
};
|
||||
|
||||
// Determine files to download
|
||||
let files_to_download = get_files_to_download(&model_id, quant);
|
||||
|
||||
// Create cache directory
|
||||
let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);
|
||||
tokio::fs::create_dir_all(&model_cache_dir)
|
||||
.await
|
||||
.context("Failed to create cache directory")?;
|
||||
|
||||
// Download each file
|
||||
for file_name in &files_to_download {
|
||||
let target_path = model_cache_dir.join(file_name);
|
||||
|
||||
// Check if file exists
|
||||
if target_path.exists() && !force {
|
||||
let size = tokio::fs::metadata(&target_path).await?.len();
|
||||
println!(
|
||||
" {} {} ({})",
|
||||
style("Cached:").green(),
|
||||
file_name,
|
||||
ByteSize(size)
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
println!(" {} {}", style("Downloading:").yellow(), file_name);
|
||||
|
||||
// Download with progress
|
||||
let downloaded_path = download_with_progress(&repo, file_name).await?;
|
||||
|
||||
// Copy to cache directory
|
||||
tokio::fs::copy(&downloaded_path, &target_path)
|
||||
.await
|
||||
.context("Failed to copy file to cache")?;
|
||||
|
||||
let size = tokio::fs::metadata(&target_path).await?.len();
|
||||
println!(
|
||||
" {} {} ({})",
|
||||
style("Downloaded:").green(),
|
||||
file_name,
|
||||
ByteSize(size)
|
||||
);
|
||||
}
|
||||
|
||||
println!();
|
||||
println!(
|
||||
"{} Model ready at: {}",
|
||||
style("Success!").green().bold(),
|
||||
model_cache_dir.display()
|
||||
);
|
||||
println!();
|
||||
|
||||
// Print usage hint
|
||||
println!("{}", "Quick start:".bold());
|
||||
println!(" ruvllm chat {}", model);
|
||||
println!(" ruvllm serve {}", model);
|
||||
println!();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Download a file with progress indication
|
||||
async fn download_with_progress(
|
||||
repo: &hf_hub::api::tokio::ApiRepo,
|
||||
file_name: &str,
|
||||
) -> Result<PathBuf> {
|
||||
// Create progress bar
|
||||
let pb = ProgressBar::new(100);
|
||||
pb.set_style(
|
||||
ProgressStyle::default_bar()
|
||||
.template(" [{elapsed_precise}] [{bar:40.cyan/blue}] {bytes}/{total_bytes} ({eta})")
|
||||
.unwrap()
|
||||
.progress_chars("#>-"),
|
||||
);
|
||||
|
||||
// Download file
|
||||
let path = repo
|
||||
.get(file_name)
|
||||
.await
|
||||
.context(format!("Failed to download {}", file_name))?;
|
||||
|
||||
pb.finish_and_clear();
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Get list of files to download for a model and quantization
|
||||
fn get_files_to_download(model_id: &str, quant: QuantPreset) -> Vec<String> {
|
||||
let mut files = vec![
|
||||
"tokenizer.json".to_string(),
|
||||
"tokenizer_config.json".to_string(),
|
||||
"config.json".to_string(),
|
||||
];
|
||||
|
||||
// Add model weights based on quantization
|
||||
if model_id.contains("GGUF") || quant != QuantPreset::None {
|
||||
// Look for GGUF files
|
||||
files.push(format!("*{}", quant.gguf_suffix()));
|
||||
} else {
|
||||
// SafeTensors format
|
||||
files.push("model.safetensors".to_string());
|
||||
}
|
||||
|
||||
// Add special tokens and chat template if available
|
||||
files.push("special_tokens_map.json".to_string());
|
||||
files.push("generation_config.json".to_string());
|
||||
|
||||
files
|
||||
}
|
||||
|
||||
/// Check if a model is already downloaded
|
||||
pub async fn is_model_downloaded(model: &str, cache_dir: &str) -> bool {
|
||||
let model_id = resolve_model_id(model);
|
||||
let model_cache_dir = PathBuf::from(cache_dir).join("models").join(&model_id);
|
||||
|
||||
// Check for tokenizer and at least one model file
|
||||
let tokenizer_exists = model_cache_dir.join("tokenizer.json").exists();
|
||||
let has_weights = tokio::fs::read_dir(&model_cache_dir)
|
||||
.await
|
||||
.ok()
|
||||
.map(|mut dir| {
|
||||
use futures::StreamExt;
|
||||
// Simplified check - just see if directory exists and has files
|
||||
true
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
tokenizer_exists && has_weights
|
||||
}
|
||||
|
||||
/// Get the path to a downloaded model
|
||||
pub fn get_model_path(model: &str, cache_dir: &str) -> PathBuf {
|
||||
let model_id = resolve_model_id(model);
|
||||
PathBuf::from(cache_dir).join("models").join(&model_id)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// With the `Q4K` preset the list includes `tokenizer.json` and an entry
    /// containing `Q4_K_M`.
    #[test]
    fn test_files_to_download() {
        let listed = get_files_to_download("test/model", QuantPreset::Q4K);

        let has_tokenizer = listed.iter().any(|name| name == "tokenizer.json");
        let has_quant_weights = listed.iter().any(|name| name.contains("Q4_K_M"));

        assert!(has_tokenizer);
        assert!(has_quant_weights);
    }
}
|
||||
Reference in New Issue
Block a user