mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 19:18:50 +00:00
Add the tensor-tools binary. (#510)
This commit is contained in:
72
candle-core/examples/tensor-tools.rs
Normal file
72
candle-core/examples/tensor-tools.rs
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
use candle_core::Result;
|
||||||
|
use clap::{Parser, Subcommand};
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug, Clone)]
|
||||||
|
enum Command {
|
||||||
|
Ls { files: Vec<std::path::PathBuf> },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug, Clone)]
|
||||||
|
struct Args {
|
||||||
|
/// Enable verbose mode.
|
||||||
|
#[arg(short, long)]
|
||||||
|
verbose: bool,
|
||||||
|
|
||||||
|
#[command(subcommand)]
|
||||||
|
command: Command,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_ls(file: &std::path::PathBuf) -> Result<()> {
|
||||||
|
match file.extension().and_then(|e| e.to_str()) {
|
||||||
|
Some("npz") => {
|
||||||
|
let tensors = candle_core::npy::NpzTensors::new(file)?;
|
||||||
|
let mut names = tensors.names();
|
||||||
|
names.sort();
|
||||||
|
for name in names {
|
||||||
|
let shape_dtype = match tensors.get_shape_and_dtype(name) {
|
||||||
|
Ok((shape, dtype)) => format!("[{shape:?}; {dtype:?}]"),
|
||||||
|
Err(err) => err.to_string(),
|
||||||
|
};
|
||||||
|
println!("{name}: {shape_dtype}")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some("safetensor") | Some("safetensors") => {
|
||||||
|
let tensors = unsafe { candle_core::safetensors::MmapedFile::new(file)? };
|
||||||
|
let tensors = tensors.deserialize()?;
|
||||||
|
let mut tensors = tensors.tensors();
|
||||||
|
tensors.sort_by(|a, b| a.0.cmp(&b.0));
|
||||||
|
for (name, view) in tensors.iter() {
|
||||||
|
let dtype = view.dtype();
|
||||||
|
let dtype = match candle_core::DType::try_from(dtype) {
|
||||||
|
Ok(dtype) => format!("{dtype:?}"),
|
||||||
|
Err(_) => format!("{dtype:?}"),
|
||||||
|
};
|
||||||
|
let shape = view.shape();
|
||||||
|
println!("{name}: [{shape:?}; {dtype}]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(_) => {
|
||||||
|
println!("{file:?}: unsupported file extension")
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
println!("{file:?}: no file extension")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
match args.command {
|
||||||
|
Command::Ls { files } => {
|
||||||
|
let multiple_files = files.len() > 1;
|
||||||
|
for file in files.iter() {
|
||||||
|
if multiple_files {
|
||||||
|
println!("--- {file:?} ---");
|
||||||
|
}
|
||||||
|
run_ls(file)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -361,6 +361,25 @@ impl NpzTensors {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn names(&self) -> Vec<&String> {
|
||||||
|
self.index_per_name.keys().collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
|
||||||
|
/// reading the whole tensor data.
|
||||||
|
pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
|
||||||
|
let index = match self.index_per_name.get(name) {
|
||||||
|
None => crate::bail!("cannot find tensor {name}"),
|
||||||
|
Some(index) => *index,
|
||||||
|
};
|
||||||
|
let zip_reader = BufReader::new(File::open(&self.path)?);
|
||||||
|
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||||
|
let mut reader = zip.by_index(index)?;
|
||||||
|
let header = read_header(&mut reader)?;
|
||||||
|
let header = Header::parse(&header)?;
|
||||||
|
Ok((header.shape(), header.descr))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||||
let index = match self.index_per_name.get(name) {
|
let index = match self.index_per_name.get(name) {
|
||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
|
Reference in New Issue
Block a user