From f861a9df6ef35bf5e2df5891d3af029e9139b0d8 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 19 Aug 2023 11:45:22 +0100
Subject: [PATCH] Add ggml support to tensor-tools (#512)

* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.

* Also support ggml files.
---
 candle-core/examples/tensor-tools.rs | 74 ++++++++++++++++++++++------
 1 file changed, 59 insertions(+), 15 deletions(-)

diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 7baee582..229ed489 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -1,9 +1,38 @@
 use candle_core::Result;
-use clap::{Parser, Subcommand};
+use clap::{Parser, Subcommand, ValueEnum};
+
+#[derive(ValueEnum, Debug, Clone)]
+enum Format {
+    Safetensors,
+    Npz,
+    Ggml,
+    PyTorch,
+    Pickle,
+}
+
+impl Format {
+    fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
+        p.as_ref()
+            .extension()
+            .and_then(|e| e.to_str())
+            .and_then(|e| match e {
+                "safetensors" | "safetensor" => Some(Self::Safetensors),
+                "npz" => Some(Self::Npz),
+                "pth" | "pt" => Some(Self::PyTorch),
+                "ggml" => Some(Self::Ggml),
+                _ => None,
+            })
+    }
+}
 
 #[derive(Subcommand, Debug, Clone)]
 enum Command {
-    Ls { files: Vec<std::path::PathBuf> },
+    Ls {
+        files: Vec<std::path::PathBuf>,
+        /// The file format to use, if unspecified infer from the file extension.
+        #[arg(long, value_enum)]
+        format: Option<Format>,
+    },
 }
 
 #[derive(Parser, Debug, Clone)]
@@ -16,9 +45,21 @@ struct Args {
     command: Command,
 }
 
-fn run_ls(file: &std::path::PathBuf) -> Result<()> {
-    match file.extension().and_then(|e| e.to_str()) {
-        Some("npz") => {
+fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
+    let format = match format {
+        Some(format) => format,
+        None => match Format::infer(file) {
+            Some(format) => format,
+            None => {
+                println!(
+                    "{file:?}: cannot infer format from file extension, use the --format flag"
+                );
+                return Ok(());
+            }
+        },
+    };
+    match format {
+        Format::Npz => {
             let tensors = candle_core::npy::NpzTensors::new(file)?;
             let mut names = tensors.names();
             names.sort();
@@ -30,7 +71,7 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{name}: {shape_dtype}")
             }
         }
-        Some("safetensor") | Some("safetensors") => {
+        Format::Safetensors => {
             let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
             let tensors = tensors.deserialize()?;
             let mut tensors = tensors.tensors();
@@ -45,14 +86,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
-        Some("pt") | Some("pth") => {
+        Format::PyTorch => {
             let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, dtype, shape) in tensors.iter() {
                 println!("{name}: [{shape:?}; {dtype:?}]")
             }
         }
-        Some("pkl") => {
+        Format::Pickle => {
             let file = std::fs::File::open(file)?;
             let mut reader = std::io::BufReader::new(file);
             let mut stack = candle_core::pickle::Stack::empty();
@@ -61,11 +102,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{i} {obj:?}");
             }
         }
-        Some(_) => {
-            println!("{file:?}: unsupported file extension")
-        }
-        None => {
-            println!("{file:?}: no file extension")
+        Format::Ggml => {
+            let mut file = std::fs::File::open(file)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
+            tensors.sort_by(|a, b| a.0.cmp(&b.0));
+            for (name, qtensor) in tensors.iter() {
+                println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
+            }
         }
     }
     Ok(())
@@ -74,13 +118,13 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match args.command {
-        Command::Ls { files } => {
+        Command::Ls { files, format } => {
             let multiple_files = files.len() > 1;
             for file in files.iter() {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file)?
+                run_ls(file, format.clone())?
             }
         }
     }