diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 229ed489..1e01d8b9 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -88,9 +88,15 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>) -> Result<()> {
         }
         Format::PyTorch => {
             let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
-            tensors.sort_by(|a, b| a.0.cmp(&b.0));
-            for (name, dtype, shape) in tensors.iter() {
-                println!("{name}: [{shape:?}; {dtype:?}]")
+            tensors.sort_by(|a, b| a.name.cmp(&b.name));
+            for tensor_info in tensors.iter() {
+                println!(
+                    "{}: [{:?}; {:?}] {:?}",
+                    tensor_info.name,
+                    tensor_info.layout.shape(),
+                    tensor_info.dtype,
+                    tensor_info.path,
+                )
             }
         }
         Format::Pickle => {
diff --git a/candle-core/src/layout.rs b/candle-core/src/layout.rs
index dc532248..bf346cf2 100644
--- a/candle-core/src/layout.rs
+++ b/candle-core/src/layout.rs
@@ -9,6 +9,14 @@ pub struct Layout {
 }
 
 impl Layout {
+    pub fn new(shape: Shape, stride: Vec<usize>, start_offset: usize) -> Self {
+        Self {
+            shape,
+            stride,
+            start_offset,
+        }
+    }
+
     pub fn contiguous_with_offset<S: Into<Shape>>(shape: S, start_offset: usize) -> Self {
         let shape = shape.into();
         let stride = shape.stride_contiguous();
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
index 059a0d9c..8a69c1e9 100644
--- a/candle-core/src/pickle.rs
+++ b/candle-core/src/pickle.rs
@@ -1,11 +1,13 @@
 // Just enough pickle support to be able to read PyTorch checkpoints.
 // This hardcodes objects that are required for tensor reading, we may want to make this a bit more
 // composable/tensor agnostic at some point.
-use crate::{DType, Error as E, Result};
+use crate::{DType, Error as E, Layout, Result};
 use byteorder::{LittleEndian, ReadBytesExt};
 use std::collections::HashMap;
 use std::io::BufRead;
 
+const VERBOSE: bool = false;
+
 // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
 #[repr(u8)]
 #[derive(Debug, Eq, PartialEq, Clone)]
@@ -352,7 +354,9 @@ impl Stack {
         match op_code {
             OpCode::Proto => {
                 let version = r.read_u8()?;
-                println!("proto {version}");
+                if VERBOSE {
+                    println!("proto {version}");
+                }
             }
             OpCode::Global => {
                 let module_name = read_to_newline(r)?;
@@ -486,11 +490,14 @@ impl From<Object> for E {
 
 // https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198
 // Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks
-fn rebuild_args(args: Object) -> Result<(Vec<usize>, DType)> {
+fn rebuild_args(args: Object) -> Result<(Layout, DType, String)> {
     let mut args = args.tuple()?;
+    let stride = Vec::<usize>::try_from(args.remove(3))?;
     let size = Vec::<usize>::try_from(args.remove(2))?;
+    let offset = args.remove(1).int()? as usize;
     let storage = args.remove(0).persistent_load()?;
     let mut storage = storage.tuple()?;
+    let path = storage.remove(2).unicode()?;
     let (_module_name, class_name) = storage.remove(1).class()?;
     let dtype = match class_name.as_str() {
         "FloatStorage" => DType::F32,
@@ -502,12 +509,19 @@ fn rebuild_args(args: Object) -> Result<(Vec<usize>, DType)> {
             crate::bail!("unsupported storage type {other}")
         }
     };
-    Ok((size, dtype))
+    let layout = Layout::new(crate::Shape::from(size), stride, offset);
+    Ok((layout, dtype, path))
 }
 
-pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
-    file: P,
-) -> Result<Vec<(String, DType, Vec<usize>)>> {
+#[derive(Debug, Clone)]
+pub struct TensorInfo {
+    pub name: String,
+    pub dtype: DType,
+    pub layout: Layout,
+    pub path: std::path::PathBuf,
+}
+
+pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
     let file = std::fs::File::open(file)?;
     let zip_reader = std::io::BufReader::new(file);
     let mut zip = zip::ZipArchive::new(zip_reader)?;
@@ -516,26 +530,44 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
         .map(|f| f.to_string())
         .collect::<Vec<_>>();
 
-    let mut tensor_info = vec![];
-    for name in zip_file_names.iter() {
-        if !name.ends_with("data.pkl") {
+    let mut tensor_infos = vec![];
+    for file_name in zip_file_names.iter() {
+        if !file_name.ends_with("data.pkl") {
             continue;
         }
-        let reader = zip.by_name(name)?;
+        let dir_name = std::path::PathBuf::from(file_name.strip_suffix(".pkl").unwrap());
+        let reader = zip.by_name(file_name)?;
         let mut reader = std::io::BufReader::new(reader);
         let mut stack = Stack::empty();
         stack.read_loop(&mut reader)?;
         let obj = stack.finalize()?;
+        if VERBOSE {
+            println!("{obj:?}");
+        }
         if let Object::Dict(key_values) = obj {
-            for (key, value) in key_values.into_iter() {
-                let key = match key.unicode() {
-                    Ok(key) => key,
+            for (name, value) in key_values.into_iter() {
+                let name = match name.unicode() {
+                    Ok(name) => name,
                     Err(_) => continue,
                 };
                 let (callable, args) = match value.reduce() {
                     Ok(callable_args) => callable_args,
                     _ => continue,
                 };
+                let (callable, args) = match callable {
+                    Object::Class {
+                        module_name,
+                        class_name,
+                    } if module_name == "torch._tensor"
+                        && class_name == "_rebuild_from_type_v2" =>
+                    {
+                        let mut args = args.tuple()?;
+                        let callable = args.remove(0);
+                        let args = args.remove(1);
+                        (callable, args)
+                    }
+                    _ => (callable, args),
+                };
                 match callable {
                     Object::Class {
                         module_name,
@@ -544,13 +576,22 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
                     _ => continue,
                 };
                 match rebuild_args(args) {
-                    Ok((size, dtype)) => tensor_info.push((key, dtype, size)),
+                    Ok((layout, dtype, file_path)) => {
+                        let mut path = dir_name.clone();
+                        path.push(file_path);
+                        tensor_infos.push(TensorInfo {
+                            name,
+                            dtype,
+                            layout,
+                            path,
+                        })
+                    }
                     Err(err) => {
-                        eprintln!("skipping {name}: {err:?}")
+                        eprintln!("skipping {name}: {err:?}")
                     }
                 }
             }
         }
     }
-    Ok(tensor_info)
+    Ok(tensor_infos)
 }
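
A minimal usage sketch of the API after this change. Everything here comes from the diff (`read_pth_tensor_info`, the `TensorInfo` fields, `layout.shape()`); the "model.pth" input path and the `main()` wrapper are illustrative assumptions, not part of the change:

    use candle_core::Result;

    fn main() -> Result<()> {
        // List tensor metadata from a PyTorch checkpoint without loading any
        // tensor data, mirroring what tensor-tools now prints for this format.
        // NOTE: "model.pth" is a hypothetical input file, not part of the diff.
        let mut tensors = candle_core::pickle::read_pth_tensor_info("model.pth")?;
        tensors.sort_by(|a, b| a.name.cmp(&b.name));
        for t in tensors.iter() {
            // `path` points at the storage file inside the checkpoint's zip
            // archive: the data.pkl directory joined with the per-tensor
            // storage key returned by rebuild_args.
            println!("{}: [{:?}; {:?}] {:?}", t.name, t.layout.shape(), t.dtype, t.path);
        }
        Ok(())
    }

Keeping the layout (shape, stride, start offset) and the in-archive path on `TensorInfo` means a caller can later seek directly to a tensor's storage rather than re-parsing data.pkl.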