Add ggml support to tensor-tools (#512)

* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.

* Also support ggml files.
Author: Laurent Mazare
Date: 2023-08-19 11:45:22 +01:00
Committed by: GitHub
Parent: ad33715c61
Commit: f861a9df6e

@@ -1,9 +1,38 @@
 use candle_core::Result;
-use clap::{Parser, Subcommand};
+use clap::{Parser, Subcommand, ValueEnum};
 
+#[derive(ValueEnum, Debug, Clone)]
+enum Format {
+    Safetensors,
+    Npz,
+    Ggml,
+    PyTorch,
+    Pickle,
+}
+
+impl Format {
+    fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
+        p.as_ref()
+            .extension()
+            .and_then(|e| e.to_str())
+            .and_then(|e| match e {
+                "safetensors" | "safetensor" => Some(Self::Safetensors),
+                "npz" => Some(Self::Npz),
+                "pth" | "pt" => Some(Self::PyTorch),
+                "ggml" => Some(Self::Ggml),
+                _ => None,
+            })
+    }
+}
+
 #[derive(Subcommand, Debug, Clone)]
 enum Command {
-    Ls { files: Vec<std::path::PathBuf> },
+    Ls {
+        files: Vec<std::path::PathBuf>,
+        /// The file format to use, if unspecified infer from the file extension.
+        #[arg(long, value_enum)]
+        format: Option<Format>,
+    },
 }
 
 #[derive(Parser, Debug, Clone)]
@@ -16,9 +45,21 @@ struct Args {
     command: Command,
 }
 
-fn run_ls(file: &std::path::PathBuf) -> Result<()> {
-    match file.extension().and_then(|e| e.to_str()) {
-        Some("npz") => {
+fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
+    let format = match format {
+        Some(format) => format,
+        None => match Format::infer(file) {
+            Some(format) => format,
+            None => {
+                println!(
+                    "{file:?}: cannot infer format from file extension, use the --format flag"
+                );
+                return Ok(());
+            }
+        },
+    };
+    match format {
+        Format::Npz => {
             let tensors = candle_core::npy::NpzTensors::new(file)?;
             let mut names = tensors.names();
             names.sort();
@@ -30,7 +71,7 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{name}: {shape_dtype}")
             }
         }
-        Some("safetensor") | Some("safetensors") => {
+        Format::Safetensors => {
             let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
             let tensors = tensors.deserialize()?;
             let mut tensors = tensors.tensors();
@@ -45,14 +86,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
-        Some("pt") | Some("pth") => {
+        Format::PyTorch => {
             let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
             tensors.sort_by(|a, b| a.0.cmp(&b.0));
             for (name, dtype, shape) in tensors.iter() {
                 println!("{name}: [{shape:?}; {dtype:?}]")
             }
         }
-        Some("pkl") => {
+        Format::Pickle => {
             let file = std::fs::File::open(file)?;
             let mut reader = std::io::BufReader::new(file);
             let mut stack = candle_core::pickle::Stack::empty();
@@ -61,11 +102,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{i} {obj:?}");
             }
         }
-        Some(_) => {
-            println!("{file:?}: unsupported file extension")
-        }
-        None => {
-            println!("{file:?}: no file extension")
+        Format::Ggml => {
+            let mut file = std::fs::File::open(file)?;
+            let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
+            let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
+            tensors.sort_by(|a, b| a.0.cmp(&b.0));
+            for (name, qtensor) in tensors.iter() {
+                println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
+            }
         }
     }
     Ok(())
@@ -74,13 +118,13 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match args.command {
-        Command::Ls { files } => {
+        Command::Ls { files, format } => {
             let multiple_files = files.len() > 1;
             for file in files.iter() {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file)?
+                run_ls(file, format.clone())?
             }
         }
     }
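
Usage sketch (not part of the diff; the exact way the tensor-tools binary is built and invoked is an assumption): with this change the ls subcommand accepts a --format flag and otherwise infers the format from the file extension, so listing the tensors of a ggml file could look like:

    # Format inferred from the .ggml extension.
    tensor-tools ls model.ggml

    # Extension missing or unrecognised: pass the format explicitly.
    tensor-tools ls --format ggml model.bin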