Add the tensor-tools binary. (#510)

This commit is contained in:
Laurent Mazare
2023-08-19 09:06:44 +01:00
committed by GitHub
parent 42e1cc8062
commit 90ff04e77e
2 changed files with 91 additions and 0 deletions

View File

@ -0,0 +1,72 @@
use candle_core::Result;
use clap::{Parser, Subcommand};
// Subcommands accepted by the tensor-tools binary.
// NOTE: plain `//` comments are used here on purpose: `///` doc comments on clap-derived
// items are turned into help text and would change the CLI output.
#[derive(Subcommand, Debug, Clone)]
enum Command {
// Lists the tensors contained in each of the given files (dispatched to `run_ls` by `main`).
Ls { files: Vec<std::path::PathBuf> },
}
// Top-level command line arguments, parsed through clap's derive API.
#[derive(Parser, Debug, Clone)]
struct Args {
// NOTE(review): this flag is parsed but not read anywhere in the visible code.
/// Enable verbose mode.
#[arg(short, long)]
verbose: bool,
// The subcommand to run.
#[command(subcommand)]
command: Command,
}
/// Prints one line per tensor stored in `file`, formatted as `name: [shape; dtype]`.
///
/// The file format is selected from the file extension:
/// - `npz`: tensor names are sorted, then the shape and dtype of each entry are printed.
/// - `safetensor` / `safetensors`: tensors are sorted by name, then printed.
///
/// Any other extension, or no extension at all, is reported on stdout and is not treated as
/// an error.
///
/// The parameter is a `&Path` rather than a `&PathBuf` so that any borrowed path can be passed;
/// existing `&PathBuf` call sites keep working through deref coercion.
///
/// # Errors
///
/// Returns an error when the file cannot be opened or when its content cannot be
/// deserialized. For `npz` files, a failure to read a single entry is printed in place of its
/// shape/dtype and does not abort the listing.
fn run_ls(file: &std::path::Path) -> Result<()> {
    match file.extension().and_then(|e| e.to_str()) {
        Some("npz") => {
            let tensors = candle_core::npy::NpzTensors::new(file)?;
            let mut names = tensors.names();
            names.sort();
            for name in names {
                // Keep going on a per-entry failure: the error message replaces the details.
                let shape_dtype = match tensors.get_shape_and_dtype(name) {
                    Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
                    Err(err) => err.to_string(),
                };
                println!("{name}: {shape_dtype}")
            }
        }
        Some("safetensor" | "safetensors") => {
            // SAFETY (assumption, to be confirmed against `MmapedFile::new`'s contract): the
            // constructor presumably memory-maps the file, which requires that the file is not
            // modified while the mapping is alive. The mapping only lives for this match arm.
            let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
            let tensors = tensors.deserialize()?;
            let mut tensors = tensors.tensors();
            tensors.sort_by(|a, b| a.0.cmp(&b.0));
            for (name, view) in tensors.iter() {
                let dtype = view.dtype();
                // Prefer the candle dtype; when there is no conversion, fall back to the dtype
                // exactly as reported by the view.
                let dtype = match candle_core::DType::try_from(dtype) {
                    Ok(dtype) => format!("{dtype:?}"),
                    Err(_) => format!("{dtype:?}"),
                };
                let shape = view.shape();
                println!("{name}: [{shape:?}; {dtype}]")
            }
        }
        Some(_) => {
            println!("{file:?}: unsupported file extension")
        }
        None => {
            println!("{file:?}: no file extension")
        }
    }
    Ok(())
}
/// Entry point: parses the command line and runs the requested subcommand.
///
/// For `ls`, the files are listed in the order they were given. When more than one file is
/// passed, a `--- <file> ---` header is printed before each listing. The first file that fails
/// stops the run and its error is returned.
fn main() -> anyhow::Result<()> {
    let cli = Args::parse();
    match cli.command {
        Command::Ls { files } => {
            // Headers are only useful to tell listings apart, so skip them for a single file.
            let print_headers = files.len() >= 2;
            for file in &files {
                if print_headers {
                    println!("--- {file:?} ---");
                }
                run_ls(file)?;
            }
        }
    }
    Ok(())
}

View File

@ -361,6 +361,25 @@ impl NpzTensors {
})
}
/// Returns the names of all the tensors indexed for this archive.
///
/// No ordering is applied here: names come out in the iteration order of the underlying
/// index, so callers that need a stable order should sort the result.
pub fn names(&self) -> Vec<&String> {
    let mut tensor_names = Vec::with_capacity(self.index_per_name.len());
    tensor_names.extend(self.index_per_name.keys());
    tensor_names
}
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
/// reading the whole tensor data.
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
let index = match self.index_per_name.get(name) {
None => crate::bail!("cannot find tensor {name}"),
Some(index) => *index,
};
let zip_reader = BufReader::new(File::open(&self.path)?);
let mut zip = zip::ZipArchive::new(zip_reader)?;
let mut reader = zip.by_index(index)?;
let header = read_header(&mut reader)?;
let header = Header::parse(&header)?;
Ok((header.shape(), header.descr))
}
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
let index = match self.index_per_name.get(name) {
None => return Ok(None),