diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index f45cbc7e..2bc1fa2e 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -1,15 +1,63 @@
-use candle_core::{Device, Result};
+use candle_core::quantized::{gguf_file, k_quants, QTensor};
+use candle_core::{Device, Result, Tensor};
 use clap::{Parser, Subcommand, ValueEnum};
 use rayon::prelude::*;
 
+#[derive(ValueEnum, Debug, Clone)]
+enum QuantizationMode {
+    /// The default quantization includes all 2d tensors, except the output tensor which always
+    /// uses Q6_K.
+    Llama,
+}
+
+impl QuantizationMode {
+    fn quantize(
+        &self,
+        name: &str,
+        tensor: QTensor,
+        default: fn(&Tensor) -> Result<QTensor>,
+    ) -> Result<QTensor> {
+        match self {
+            Self::Llama => {
+                // Same behavior as the llama.cpp quantization.
+                let should_quantize = name.ends_with(".weight") && tensor.rank() == 2;
+                if should_quantize {
+                    let tensor = tensor.dequantize(&Device::Cpu)?;
+                    if name == "output.weight" {
+                        QTensor::quantize::<k_quants::BlockQ6K>(&tensor)
+                    } else {
+                        default(&tensor)
+                    }
+                } else {
+                    Ok(tensor)
+                }
+            }
+        }
+    }
+}
+
 #[derive(ValueEnum, Debug, Clone)]
 enum Quantization {
+    #[value(name = "q4_0")]
+    Q4_0,
+    #[value(name = "q4_1")]
+    Q4_1,
+    #[value(name = "q5_0")]
+    Q5_0,
+    #[value(name = "q5_1")]
+    Q5_1,
+    #[value(name = "q8_0")]
+    Q8_0,
+    #[value(name = "q8_1")]
+    Q8_1,
     Q2k,
     Q3k,
     Q4k,
     Q5k,
     Q6k,
     Q8k,
+    F16,
+    F32,
 }
 
 #[derive(ValueEnum, Debug, Clone)]
@@ -62,6 +110,10 @@ enum Command {
         /// The quantization schema to apply.
         #[arg(long, value_enum)]
         quantization: Quantization,
+
+        /// Which tensor to quantize.
+        #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
+        mode: QuantizationMode,
     },
 }
 
@@ -147,7 +199,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
         }
         Format::Gguf => {
            let mut file = std::fs::File::open(file)?;
-            let content = candle_core::quantized::gguf_file::Content::read(&mut file)?;
+            let content = gguf_file::Content::read(&mut file)?;
            if verbose {
                let mut metadata = content.metadata.into_iter().collect::<Vec<_>>();
                metadata.sort_by(|a, b| a.0.cmp(&b.0));
@@ -170,14 +222,31 @@ fn run_quantize(
     in_file: std::path::PathBuf,
     out_file: std::path::PathBuf,
     q: Quantization,
+    qmode: QuantizationMode,
 ) -> Result<()> {
-    use candle_core::quantized::{gguf_file, k_quants, QTensor};
     // Open the out file early so as to fail directly on missing directories etc.
     let mut out_file = std::fs::File::create(out_file)?;
     let mut in_ = std::fs::File::open(&in_file)?;
     let content = gguf_file::Content::read(&mut in_)?;
     println!("tensors: {}", content.tensor_infos.len());
 
+    let quantize_fn = match q {
+        Quantization::Q4_0 => QTensor::quantize::<k_quants::BlockQ4_0>,
+        Quantization::Q4_1 => QTensor::quantize::<k_quants::BlockQ4_1>,
+        Quantization::Q5_0 => QTensor::quantize::<k_quants::BlockQ5_0>,
+        Quantization::Q5_1 => QTensor::quantize::<k_quants::BlockQ5_1>,
+        Quantization::Q8_0 => QTensor::quantize::<k_quants::BlockQ8_0>,
+        Quantization::Q8_1 => QTensor::quantize::<k_quants::BlockQ8_1>,
+        Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>,
+        Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>,
+        Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>,
+        Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>,
+        Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>,
+        Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>,
+        Quantization::F16 => QTensor::quantize::<half::f16>,
+        Quantization::F32 => QTensor::quantize::<f32>,
+    };
+
     let qtensors = content
         .tensor_infos
         .par_iter()
@@ -185,17 +254,7 @@ fn run_quantize(
             println!("  quantizing {name}");
             let mut in_file = std::fs::File::open(&in_file)?;
             let tensor = content.tensor(&mut in_file, name)?;
-            let tensor = tensor.dequantize(&Device::Cpu)?;
-            // TODO: Only quantize the linear weights, and quantize the final layer weights
-            // differently from the rest.
-            let tensor = match q {
-                Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>(&tensor)?,
-                Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>(&tensor)?,
-                Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>(&tensor)?,
-                Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>(&tensor)?,
-                Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>(&tensor)?,
-                Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>(&tensor)?,
-            };
+            let tensor = qmode.quantize(name, tensor, quantize_fn)?;
             Ok((name, tensor))
         })
         .collect::<Result<Vec<_>>>()?;
@@ -233,7 +292,8 @@ fn main() -> anyhow::Result<()> {
             in_file,
             out_file,
             quantization,
-        } => run_quantize(in_file, out_file, quantization)?,
+            mode,
+        } => run_quantize(in_file, out_file, quantization, mode)?,
     }
     Ok(())
 }
diff --git a/candle-core/src/quantized/mod.rs b/candle-core/src/quantized/mod.rs
index cb788779..d87d2d5a 100644
--- a/candle-core/src/quantized/mod.rs
+++ b/candle-core/src/quantized/mod.rs
@@ -202,6 +202,10 @@ impl QTensor {
         self.data.dtype()
     }
 
+    pub fn rank(&self) -> usize {
+        self.shape.rank()
+    }
+
     pub fn shape(&self) -> &Shape {
         &self.shape
     }
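
Note on the dispatch in run_quantize: every monomorphized QTensor::quantize::<T> is an ordinary function, so the match over Quantization can collapse the fourteen instantiations into one uniform fn(&Tensor) -> Result<QTensor> pointer, which QuantizationMode::quantize then applies only to the tensors that should be re-quantized. Below is a minimal self-contained sketch of that pattern with toy stand-ins for the candle types; all names in it are illustrative, not candle's API.

// Toy stand-ins for candle_core::Tensor / QTensor, only to illustrate the
// function-pointer dispatch used in run_quantize above.
struct Tensor(Vec<f32>);
struct QTensor(Vec<u8>);
type Result<T> = std::result::Result<T, String>;

// One generic entry point, analogous to QTensor::quantize::<T>.
fn quantize<const BITS: u8>(t: &Tensor) -> Result<QTensor> {
    // Crude truncation "quantization"; the real schemes are block-based.
    Ok(QTensor(t.0.iter().map(|x| (*x as u8) >> (8 - BITS)).collect()))
}

enum Quantization {
    Q4,
    Q8,
}

fn main() -> Result<()> {
    let q = Quantization::Q4;
    // Each arm names a distinct monomorphization, but all of them coerce
    // to the same uniform `fn` pointer type, as in the match over
    // Quantization in run_quantize.
    let quantize_fn: fn(&Tensor) -> Result<QTensor> = match q {
        Quantization::Q4 => quantize::<4>,
        Quantization::Q8 => quantize::<8>,
    };
    let out = quantize_fn(&Tensor(vec![128.0, 64.0]))?;
    assert_eq!(out.0, vec![8, 4]);
    Ok(())
}

Passing a plain fn pointer (rather than a generic parameter or a boxed closure) lets the choice happen once, before the par_iter loop, and keeps QuantizationMode::quantize itself non-generic.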
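
For completeness, an invocation exercising the new flag might look like the line below. The patch only shows that `quantization` and `mode` are long value-enum arguments; how in_file/out_file are parsed is not visible here, so the positional form is an assumption, and --mode can be omitted since it defaults to llama.

cargo run --example tensor-tools --release -- quantize --quantization q4_0 --mode llama model.gguf model-q4_0.gguf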