After rebase.

This commit is contained in:
Nicolas Patry
2024-01-11 17:17:55 +01:00
parent 3aefc709c7
commit 9ef040338d

View File

@ -286,13 +286,17 @@ fn run_quantize_safetensors(
Ok(())
}
fn run_dequantize(in_file: std::path::PathBuf, out_file: std::path::PathBuf) -> Result<()> {
fn run_dequantize(
in_file: std::path::PathBuf,
out_file: std::path::PathBuf,
device: &Device,
) -> Result<()> {
let mut in_file = std::fs::File::open(in_file)?;
let content = gguf_file::Content::read(&mut in_file)?;
let mut tensors = std::collections::HashMap::new();
for (tensor_name, _) in content.tensor_infos.iter() {
let tensor = content.tensor(&mut in_file, tensor_name)?;
let tensor = tensor.dequantize(&Device::Cpu)?;
let tensor = content.tensor(&mut in_file, tensor_name, device)?;
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
@ -379,7 +383,7 @@ fn main() -> anyhow::Result<()> {
quantization,
mode,
} => run_quantize(&in_file, out_file, quantization, mode, &device)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file)?,
Command::Dequantize { in_file, out_file } => run_dequantize(in_file, out_file, &device)?,
}
Ok(())
}