Add a dequantize command to tensor-tools. (#1565)

* Add a dequantize command to tensor-tools.

* Clippy fixes.
This commit is contained in:
Laurent Mazare
2024-01-11 11:21:01 +01:00
committed by GitHub
parent 2480c5dbdd
commit 0fc95c9f0c

View File

@ -102,7 +102,7 @@ enum Command {
}, },
Quantize { Quantize {
/// The input file, in gguf format. /// The input file(s), in safetensors format.
in_file: Vec<std::path::PathBuf>, in_file: Vec<std::path::PathBuf>,
/// The output file, in gguf format. /// The output file, in gguf format.
@ -117,6 +117,15 @@ enum Command {
#[arg(long, value_enum, default_value_t = QuantizationMode::Llama)] #[arg(long, value_enum, default_value_t = QuantizationMode::Llama)]
mode: QuantizationMode, mode: QuantizationMode,
}, },
Dequantize {
/// The input file, in gguf format.
in_file: std::path::PathBuf,
/// The output file, in safetensors format.
#[arg(long)]
out_file: std::path::PathBuf,
},
} }
#[derive(Parser, Debug, Clone)] #[derive(Parser, Debug, Clone)]
@ -285,6 +294,19 @@ fn run_quantize_safetensors(
Ok(()) Ok(())
} }
/// Converts a gguf file into a safetensors file holding dequantized tensors.
///
/// Opens `in_file`, reads its gguf content, then loads every tensor listed in
/// the content's `tensor_infos`, dequantizes it on the CPU device, and stores
/// it under its original name. The resulting map is written to `out_file`
/// with `candle_core::safetensors::save`.
///
/// # Errors
///
/// Any failure while opening the input, reading the gguf content, loading or
/// dequantizing a tensor, or saving the output is propagated to the caller.
fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
    let mut reader = std::fs::File::open(in_file)?;
    let content = gguf_file::Content::read(&mut reader)?;

    let mut dequantized = std::collections::HashMap::new();
    // Only the tensor names are needed here; the per-tensor info is looked up
    // again by `Content::tensor` from the name.
    for name in content.tensor_infos.keys() {
        let tensor = content
            .tensor(&mut reader, name)?
            .dequantize(&Device::Cpu)?;
        dequantized.insert(name.clone(), tensor);
    }

    candle_core::safetensors::save(&dequantized, out_file)?;
    Ok(())
}
fn run_quantize( fn run_quantize(
in_files: &[std::path::PathBuf], in_files: &[std::path::PathBuf],
out_file: std::path::PathBuf, out_file: std::path::PathBuf,
@ -379,6 +401,7 @@ fn main() -> anyhow::Result<()> {
quantization, quantization,
mode, mode,
} => run_quantize(&in_file, out_file, quantization, mode)?, } => run_quantize(&in_file, out_file, quantization, mode)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
} }
Ok(()) Ok(())
} }