mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Add ggml support to tensor-tools (#512)
* Pickle work-in-progress.
* More unpickling.
* More pickling.
* Proper handling of setitems.
* Clippy.
* Again more pickling.
* Restore the example.
* Add enough pickle support to get the list of tensors.
* Read the data from zip files.
* Retrieve the tensor shape.
* Extract the size and dtype.
* More storage types.
* Improve the destructuring.
* Also support ggml files.
This commit is contained in:
@ -1,9 +1,38 @@
|
|||||||
use candle_core::Result;
|
use candle_core::Result;
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand, ValueEnum};
|
||||||
|
|
||||||
|
#[derive(ValueEnum, Debug, Clone)]
|
||||||
|
enum Format {
|
||||||
|
Safetensors,
|
||||||
|
Npz,
|
||||||
|
Ggml,
|
||||||
|
PyTorch,
|
||||||
|
Pickle,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Format {
|
||||||
|
fn infer<P: AsRef<std::path::Path>>(p: P) -> Option<Self> {
|
||||||
|
p.as_ref()
|
||||||
|
.extension()
|
||||||
|
.and_then(|e| e.to_str())
|
||||||
|
.and_then(|e| match e {
|
||||||
|
"safetensors" | "safetensor" => Some(Self::Safetensors),
|
||||||
|
"npz" => Some(Self::Npz),
|
||||||
|
"pth" | "pt" => Some(Self::PyTorch),
|
||||||
|
"ggml" => Some(Self::Ggml),
|
||||||
|
_ => None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Subcommand, Debug, Clone)]
|
#[derive(Subcommand, Debug, Clone)]
|
||||||
enum Command {
|
enum Command {
|
||||||
Ls { files: Vec<std::path::PathBuf> },
|
Ls {
|
||||||
|
files: Vec<std::path::PathBuf>,
|
||||||
|
/// The file format to use, if unspecified infer from the file extension.
|
||||||
|
#[arg(long, value_enum)]
|
||||||
|
format: Option<Format>,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone)]
|
#[derive(Parser, Debug, Clone)]
|
||||||
@ -16,9 +45,21 @@ struct Args {
|
|||||||
command: Command,
|
command: Command,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
|
||||||
match file.extension().and_then(|e| e.to_str()) {
|
let format = match format {
|
||||||
Some("npz") => {
|
Some(format) => format,
|
||||||
|
None => match Format::infer(file) {
|
||||||
|
Some(format) => format,
|
||||||
|
None => {
|
||||||
|
println!(
|
||||||
|
"{file:?}: cannot infer format from file extension, use the --format flag"
|
||||||
|
);
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
},
|
||||||
|
};
|
||||||
|
match format {
|
||||||
|
Format::Npz => {
|
||||||
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||||
let mut names = tensors.names();
|
let mut names = tensors.names();
|
||||||
names.sort();
|
names.sort();
|
||||||
@ -30,7 +71,7 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
|||||||
println!("{name}: {shape_dtype}")
|
println!("{name}: {shape_dtype}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some("safetensor") | Some("safetensors") => {
|
Format::Safetensors => {
|
||||||
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
||||||
let tensors = tensors.deserialize()?;
|
let tensors = tensors.deserialize()?;
|
||||||
let mut tensors = tensors.tensors();
|
let mut tensors = tensors.tensors();
|
||||||
@ -45,14 +86,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
|||||||
println!("{name}: [{shape:?}; {dtype}]")
|
println!("{name}: [{shape:?}; {dtype}]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some("pt") | Some("pth") => {
|
Format::PyTorch => {
|
||||||
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
|
||||||
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
for (name, dtype, shape) in tensors.iter() {
|
for (name, dtype, shape) in tensors.iter() {
|
||||||
println!("{name}: [{shape:?}; {dtype:?}]")
|
println!("{name}: [{shape:?}; {dtype:?}]")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some("pkl") => {
|
Format::Pickle => {
|
||||||
let file = std::fs::File::open(file)?;
|
let file = std::fs::File::open(file)?;
|
||||||
let mut reader = std::io::BufReader::new(file);
|
let mut reader = std::io::BufReader::new(file);
|
||||||
let mut stack = candle_core::pickle::Stack::empty();
|
let mut stack = candle_core::pickle::Stack::empty();
|
||||||
@ -61,11 +102,14 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
|||||||
println!("{i} {obj:?}");
|
println!("{i} {obj:?}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Some(_) => {
|
Format::Ggml => {
|
||||||
println!("{file:?}: unsupported file extension")
|
let mut file = std::fs::File::open(file)?;
|
||||||
}
|
let content = candle_core::quantized::ggml_file::Content::read(&mut file)?;
|
||||||
None => {
|
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
|
||||||
println!("{file:?}: no file extension")
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
|
for (name, qtensor) in tensors.iter() {
|
||||||
|
println!("{name}: [{:?}; {:?}]", qtensor.shape(), qtensor.dtype());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
@ -74,13 +118,13 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
|||||||
fn main() -> anyhow::Result<()> {
|
fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
match args.command {
|
match args.command {
|
||||||
Command::Ls { files } => {
|
Command::Ls { files, format } => {
|
||||||
let multiple_files = files.len() > 1;
|
let multiple_files = files.len() > 1;
|
||||||
for file in files.iter() {
|
for file in files.iter() {
|
||||||
if multiple_files {
|
if multiple_files {
|
||||||
println!("--- {file:?} ---");
|
println!("--- {file:?} ---");
|
||||||
}
|
}
|
||||||
run_ls(file)?
|
run_ls(file, format.clone())?
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user