diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs new file mode 100644 index 00000000..03ea923b --- /dev/null +++ b/candle-core/examples/tensor-tools.rs @@ -0,0 +1,72 @@ +use candle_core::Result; +use clap::{Parser, Subcommand}; + +#[derive(Subcommand, Debug, Clone)] +enum Command { + Ls { files: Vec }, +} + +#[derive(Parser, Debug, Clone)] +struct Args { + /// Enable verbose mode. + #[arg(short, long)] + verbose: bool, + + #[command(subcommand)] + command: Command, +} + +fn run_ls(file: &std::path::PathBuf) -> Result<()> { + match file.extension().and_then(|e| e.to_str()) { + Some("npz") => { + let tensors = candle_core::npy::NpzTensors::new(file)?; + let mut names = tensors.names(); + names.sort(); + for name in names { + let shape_dtype = match tensors.get_shape_and_dtype(name) { + Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"), + Err(err) => err.to_string(), + }; + println!("{name}: {shape_dtype}") + } + } + Some("safetensor") | Some("safetensors") => { + let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? }; + let tensors = tensors.deserialize()?; + let mut tensors = tensors.tensors(); + tensors.sort_by(|a, b| a.0.cmp(&b.0)); + for (name, view) in tensors.iter() { + let dtype = view.dtype(); + let dtype = match candle_core::DType::try_from(dtype) { + Ok(dtype) => format!("{dtype:?}"), + Err(_) => format!("{dtype:?}"), + }; + let shape = view.shape(); + println!("{name}: [{shape:?}; {dtype}]") + } + } + Some(_) => { + println!("{file:?}: unsupported file extension") + } + None => { + println!("{file:?}: no file extension") + } + } + Ok(()) +} + +fn main() -> anyhow::Result<()> { + let args = Args::parse(); + match args.command { + Command::Ls { files } => { + let multiple_files = files.len() > 1; + for file in files.iter() { + if multiple_files { + println!("--- {file:?} ---"); + } + run_ls(file)? + } + } + } + Ok(()) +} diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index 2e394b06..b7cfbda0 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -361,6 +361,25 @@ impl NpzTensors { }) } + pub fn names(&self) -> Vec<&String> { + self.index_per_name.keys().collect() + } + + /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids + /// reading the whole tensor data. + pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> { + let index = match self.index_per_name.get(name) { + None => crate::bail!("cannot find tensor {name}"), + Some(index) => *index, + }; + let zip_reader = BufReader::new(File::open(&self.path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut reader = zip.by_index(index)?; + let header = read_header(&mut reader)?; + let header = Header::parse(&header)?; + Ok((header.shape(), header.descr)) + } + pub fn get(&self, name: &str) -> Result> { let index = match self.index_per_name.get(name) { None => return Ok(None),