diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index e2d12aa5..19c9d0ca 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -6,7 +6,7 @@ enum Format {
     Safetensors,
     Npz,
     Ggml,
-    PyTorch,
+    Pth,
     Pickle,
 }
 
@@ -16,9 +16,10 @@ impl Format {
             .extension()
            .and_then(|e| e.to_str())
             .and_then(|e| match e {
+                // We don't infer any format for .bin as it can be used for ggml or pytorch.
                 "safetensors" | "safetensor" => Some(Self::Safetensors),
                 "npz" => Some(Self::Npz),
-                "pth" | "pt" => Some(Self::PyTorch),
+                "pth" | "pt" => Some(Self::Pth),
                 "ggml" => Some(Self::Ggml),
                 _ => None,
             })
@@ -29,18 +30,19 @@ impl Format {
 enum Command {
     Ls {
         files: Vec<std::path::PathBuf>,
 
+        /// The file format to use, if unspecified infer from the file extension.
         #[arg(long, value_enum)]
         format: Option<Format>,
+
+        /// Enable verbose mode.
+        #[arg(short, long)]
+        verbose: bool,
     },
 }
 
 #[derive(Parser, Debug, Clone)]
 struct Args {
-    /// Enable verbose mode.
-    #[arg(short, long)]
-    verbose: bool,
-
     #[command(subcommand)]
     command: Command,
 }
@@ -86,7 +88,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
-        Format::PyTorch => {
+        Format::Pth => {
             let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
             tensors.sort_by(|a, b| a.name.cmp(&b.name));
             for tensor_info in tensors.iter() {
@@ -126,13 +128,17 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
 fn main() -> anyhow::Result<()> {
     let args = Args::parse();
     match args.command {
-        Command::Ls { files, format } => {
+        Command::Ls {
+            files,
+            format,
+            verbose,
+        } => {
             let multiple_files = files.len() > 1;
             for file in files.iter() {
                 if multiple_files {
                     println!("--- {file:?} ---");
                 }
-                run_ls(file, format.clone(), args.verbose)?
+                run_ls(file, format.clone(), verbose)?
             }
         }
     }
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index f14a5046..e913935c 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -490,13 +490,14 @@ impl From for E {
 
 // https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
 // Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
-fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
+fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
     let mut args = args.tuple()?;
     let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
     let offset = args.remove(1).int()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
+    let storage_size = storage.remove(4).int()? as usize;
     let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
@@ -510,7 +511,7 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
         }
     };
     let layout = Layout::new(crate::Shape::from(size), stride, offset);
-    Ok((layout, dtype, path))
+    Ok((layout, dtype, path, storage_size))
 }
 
 #[derive(Debug, Clone)]
@@ -519,6 +520,7 @@ pub struct TensorInfo {
     pub dtype: DType,
     pub layout: Layout,
     pub path: String,
+    pub storage_size: usize,
 }
 
 pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
@@ -576,7 +578,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
                     _ => continue,
                 };
                 match rebuild_args(args) {
-                    Ok((layout, dtype, file_path)) => {
+                    Ok((layout, dtype, file_path, storage_size)) => {
                         let mut path = dir_name.clone();
                         path.push(file_path);
                         tensor_infos.push(TensorInfo {
@@ -584,6 +586,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te