Small tweaks to tensor-tools. (#517)

commit 551409092e (parent 6431140250)
Author: Laurent Mazare
Date: 2023-08-19 16:50:26 +01:00
Committed by: GitHub
2 changed files with 21 additions and 12 deletions

File 1 of 2: the tensor-tools binary

@@ -6,7 +6,7 @@ enum Format {
     Safetensors,
     Npz,
     Ggml,
-    PyTorch,
+    Pth,
     Pickle,
 }
@@ -16,9 +16,10 @@ impl Format {
             .extension()
             .and_then(|e| e.to_str())
             .and_then(|e| match e {
+                // We don't infer any format for .bin as it can be used for ggml or pytorch.
                 "safetensors" | "safetensor" => Some(Self::Safetensors),
                 "npz" => Some(Self::Npz),
-                "pth" | "pt" => Some(Self::PyTorch),
+                "pth" | "pt" => Some(Self::Pth),
                 "ggml" => Some(Self::Ggml),
                 _ => None,
             })
@@ -29,18 +30,19 @@ impl Format {
 enum Command {
     Ls {
         files: Vec<std::path::PathBuf>,
         /// The file format to use, if unspecified infer from the file extension.
         #[arg(long, value_enum)]
         format: Option<Format>,
+        /// Enable verbose mode.
+        #[arg(short, long)]
+        verbose: bool,
     },
 }

 #[derive(Parser, Debug, Clone)]
 struct Args {
-    /// Enable verbose mode.
-    #[arg(short, long)]
-    verbose: bool,
     #[command(subcommand)]
     command: Command,
 }
@@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
-        Format::PyTorch => {
+        Format::Pth => {
             let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
@@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match args.command {
-        Command::Ls { files, format } => {
+        Command::Ls {
+            files,
+            format,
+            verbose,
+        } => {
             let multiple_files = files.len() > 1;
             for file in files.iter() {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), args.verbose)?
+                run_ls(file, format.clone(), verbose)?
             }
         }
     }
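After this change the verbose flag lives on the ls subcommand rather than on the top-level Args, so it is passed after the subcommand (for example "ls --verbose model.pth" instead of "--verbose ls model.pth"; the binary and file names here are placeholders). And since, per the new comment above, no format is inferred for .bin files, those need an explicit --format. Below is a minimal, self-contained sketch of the resulting CLI surface, assuming clap 4 with the derive feature; the main body is illustrative only and not part of the commit.

// A minimal sketch of the CLI surface after this commit, assuming clap 4 with
// the `derive` feature enabled. Names mirror the diff; `main` is illustrative.
use clap::{Parser, Subcommand, ValueEnum};

#[derive(ValueEnum, Debug, Clone)]
enum Format {
    Safetensors,
    Npz,
    Ggml,
    Pth,
    Pickle,
}

#[derive(Subcommand, Debug, Clone)]
enum Command {
    Ls {
        files: Vec<std::path::PathBuf>,
        /// The file format to use, if unspecified infer from the file extension.
        #[arg(long, value_enum)]
        format: Option<Format>,
        /// Enable verbose mode.
        #[arg(short, long)]
        verbose: bool,
    },
}

#[derive(Parser, Debug, Clone)]
struct Args {
    #[command(subcommand)]
    command: Command,
}

fn main() {
    let args = Args::parse();
    match args.command {
        Command::Ls { files, format, verbose } => {
            println!("ls files={files:?} format={format:?} verbose={verbose}")
        }
    }
}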

File 2 of 2: the pickle reader (candle_core::pickle)

@@ -490,13 +490,14 @@ impl From<Object> for E {
 // https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
 // Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
-fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
+fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
     let mut args = args.tuple()?;
     let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
     let offset = args.remove(1).int()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
+    let storage_size = storage.remove(4).int()? as usize;
     let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
@@ -510,7 +511,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
         }
     };
     let layout = Layout::new(crate::Shape::from(size), stride, offset);
-    Ok((layout, dtype, path))
+    Ok((layout, dtype, path, storage_size))
 }

 #[derive(Debug, Clone)]
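The indices used in rebuild_args follow from the tuple PyTorch's zip-file serializer records as a storage's persistent id, which, to the best of my understanding, has the shape ("storage", storage_type, key, location, numel): the archive-relative key sits at position 2 and the element count at position 4, which is what the new storage_size captures. The sketch below only illustrates that assumed layout, using a plain string slice as a stand-in for the crate's Object type:

// Hypothetical stand-in for the persistent_id tuple layout assumed above:
//   ("storage", StorageType, key, location, numel)
fn key_and_numel(persistent_id: &[&str]) -> Option<(String, usize)> {
    let key = persistent_id.get(2)?.to_string(); // archive-relative data file, e.g. "0"
    let numel = persistent_id.get(4)?.parse().ok()?; // element count of the storage
    Some((key, numel))
}

fn main() {
    let pid = ["storage", "FloatStorage", "0", "cpu", "1024"];
    assert_eq!(key_and_numel(&pid), Some(("0".to_string(), 1024)));
}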
@@ -519,6 +520,7 @@ pub struct TensorInfo {
     pub dtype: DType,
     pub layout: Layout,
     pub path: String,
+    pub storage_size: usize,
 }

 pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
@@ -576,7 +578,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
             _ => continue,
         };
         match rebuild_args(args) {
-            Ok((layout, dtype, file_path)) => {
+            Ok((layout, dtype, file_path, storage_size)) => {
                 let mut path = dir_name.clone();
                 path.push(file_path);
                 tensor_infos.push(TensorInfo {
@@ -584,6 +586,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
                     dtype,
                     layout,
                     path: path.to_string_lossy().into_owned(),
+                    storage_size,
                 })
             }
             Err(err) => {
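Finally, a hypothetical downstream use of the new storage_size field: listing each tensor with the approximate byte footprint of its storage. This is a sketch only, assuming storage_size counts elements rather than bytes, that DType::size_in_bytes and Layout::shape are available as in candle-core, and that "model.pth" is a placeholder path; none of it is part of the commit.

// Sketch only: assumes candle-core and anyhow as dependencies and a placeholder
// checkpoint path. `storage_size` is treated as an element count.
fn main() -> anyhow::Result<()> {
    let mut infos = candle_core::pickle::read_pth_tensor_info("model.pth")?;
    infos.sort_by(|a, b| a.name.cmp(&b.name));
    for info in infos.iter() {
        let approx_bytes = info.storage_size * info.dtype.size_in_bytes();
        println!(
            "{}: shape {:?}, dtype {:?}, ~{} bytes of storage ({})",
            info.name,
            info.layout.shape(),
            info.dtype,
            approx_bytes,
            info.path
        );
    }
    Ok(())
}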