Add the quantize command. (#624)

* Add the quantize command.

* Bugfix for writing gguf files.

* And add a comment.
This commit is contained in:
Laurent Mazare
2023-08-27 11:35:19 +01:00
committed by GitHub
parent 6e485f2deb
commit 7151f2cf63
2 changed files with 77 additions and 2 deletions

View File

@ -1,5 +1,16 @@
use candle_core::Result; use candle_core::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum}; use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;
// Quantization formats that can be requested through the `--quantization`
// argument of the `Quantize` sub-command. In `run_quantize` each variant
// selects the matching block type from `candle_core::quantized::k_quants`.
#[derive(ValueEnum, Debug, Clone)]
enum Quantization {
// Quantize with `k_quants::BlockQ2K`.
Q2k,
// Quantize with `k_quants::BlockQ3K`.
Q3k,
// Quantize with `k_quants::BlockQ4K`.
Q4k,
// Quantize with `k_quants::BlockQ5K`.
Q5k,
// Quantize with `k_quants::BlockQ6K`.
Q6k,
// Quantize with `k_quants::BlockQ8K`.
Q8k,
}
#[derive(ValueEnum, Debug, Clone)] #[derive(ValueEnum, Debug, Clone)]
enum Format { enum Format {
@ -41,6 +52,17 @@ enum Command {
#[arg(short, long)] #[arg(short, long)]
verbose: bool, verbose: bool,
}, },
Quantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in gguf format.
out_file: std::path::PathBuf,
/// The quantization schema to apply.
#[arg(long, value_enum)]
quantization: Quantization,
},
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -144,6 +166,53 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
Ok(()) Ok(())
} }
/// Re-quantizes every tensor of a gguf file and writes the result to another gguf file.
///
/// Each tensor listed in `in_file` is loaded, dequantized on the CPU, and quantized again
/// with the k-quant block type selected by `q`. The metadata entries read from the input
/// are handed to the writer unchanged.
///
/// # Errors
///
/// Fails if the output file cannot be created, if the input file cannot be opened or
/// parsed, or if loading, dequantizing, quantizing or writing any tensor fails.
fn run_quantize(
    in_file: std::path::PathBuf,
    out_file: std::path::PathBuf,
    q: Quantization,
) -> Result<()> {
    use candle_core::quantized::{gguf_file, k_quants, QTensor};

    // The destination is created before anything else so that an unusable output path
    // (missing directory, ...) is reported right away rather than after the heavy work.
    let mut writer = std::fs::File::create(out_file)?;

    let mut header_reader = std::fs::File::open(&in_file)?;
    let content = gguf_file::Content::read(&mut header_reader)?;
    println!("tensors: {}", content.tensor_infos.len());

    // Tensors are handled through rayon's parallel iterator; every closure invocation
    // opens its own handle on the input file instead of sharing a single reader.
    let requantized = content
        .tensor_infos
        .par_iter()
        .map(|(name, _)| {
            println!(" quantizing {name}");
            let mut reader = std::fs::File::open(&in_file)?;
            let dequantized = content
                .tensor(&mut reader, name)?
                .dequantize(&Device::Cpu)?;
            // TODO: Only quantize the linear weights, and quantize the final layer weights
            // differently from the rest.
            let quantized = match q {
                Quantization::Q2k => QTensor::quantize::<k_quants::BlockQ2K>(&dequantized),
                Quantization::Q3k => QTensor::quantize::<k_quants::BlockQ3K>(&dequantized),
                Quantization::Q4k => QTensor::quantize::<k_quants::BlockQ4K>(&dequantized),
                Quantization::Q5k => QTensor::quantize::<k_quants::BlockQ5K>(&dequantized),
                Quantization::Q6k => QTensor::quantize::<k_quants::BlockQ6K>(&dequantized),
                Quantization::Q8k => QTensor::quantize::<k_quants::BlockQ8K>(&dequantized),
            }?;
            Ok((name, quantized))
        })
        .collect::<Result<Vec<_>>>()?;

    // The writer is given borrowed `(name, value)` pairs, so build those views here.
    let tensor_refs: Vec<_> = requantized
        .iter()
        .map(|(name, tensor)| (name.as_str(), tensor))
        .collect();
    let metadata_refs: Vec<_> = content
        .metadata
        .iter()
        .map(|(key, value)| (key.as_str(), value))
        .collect();

    gguf_file::write(&mut writer, metadata_refs.as_slice(), &tensor_refs)?;
    Ok(())
}
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let args = Args::parse(); let args = Args::parse();
match args.command { match args.command {
@ -160,6 +229,11 @@ fn main() -> anyhow::Result<()> {
run_ls(file, format.clone(), verbose)? run_ls(file, format.clone(), verbose)?
} }
} }
Command::Quantize {
in_file,
out_file,
quantization,
} => run_quantize(in_file, out_file, quantization)?,
} }
Ok(()) Ok(())
} }

View File

@ -292,7 +292,7 @@ impl ValueType {
7 => Self::Bool, 7 => Self::Bool,
8 => Self::String, 8 => Self::String,
9 => Self::Array, 9 => Self::Array,
v => crate::bail!("unrecognized value-type {v}"), v => crate::bail!("unrecognized value-type {v:#08x}"),
}; };
Ok(v) Ok(v)
} }
@ -393,6 +393,7 @@ pub fn write<W: std::io::Seek + std::io::Write>(
w.write_u32::<LittleEndian>(0x46554747)?; w.write_u32::<LittleEndian>(0x46554747)?;
w.write_u32::<LittleEndian>(1)?; // version 1. w.write_u32::<LittleEndian>(1)?; // version 1.
w.write_u32::<LittleEndian>(tensors.len() as u32)?; w.write_u32::<LittleEndian>(tensors.len() as u32)?;
w.write_u32::<LittleEndian>(metadata.len() as u32)?;
for (name, value) in metadata.iter() { for (name, value) in metadata.iter() {
write_string(w, name)?; write_string(w, name)?;
w.write_u32::<LittleEndian>(value.value_type().to_u32())?; w.write_u32::<LittleEndian>(value.value_type().to_u32())?;