diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index 1e01d8b9..e2d12aa5 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -45,7 +45,7 @@ struct Args { command: Command, } -fn run_ls(file: &std::path::PathBuf, format: Option) -> Result<()> { +fn run_ls(file: &std::path::PathBuf, format: Option, verbose: bool) -> Result<()> { let format = match format { Some(format) => format, None => match Format::infer(file) { @@ -91,12 +91,14 @@ fn run_ls(file: &std::path::PathBuf, format: Option) -> Result<()> { tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { println!( - "{}: [{:?}; {:?}] {:?}", + "{}: [{:?}; {:?}]", tensor_info.name, tensor_info.layout.shape(), tensor_info.dtype, - tensor_info.path, - ) + ); + if verbose { + println!(" {:?}", tensor_info); + } } } Format::Pickle => { @@ -130,7 +132,7 @@ fn main() -> anyhow::Result<()> { if multiple_files { println!("--- {file:?} ---"); } - run_ls(file, format.clone())? + run_ls(file, format.clone(), args.verbose)? } } } diff --git a/candle-core/src/npy.rs b/candle-core/src/npy.rs index b7cfbda0..62c62f9e 100644 --- a/candle-core/src/npy.rs +++ b/candle-core/src/npy.rs @@ -196,7 +196,11 @@ impl Header { impl Tensor { // TODO: Add the possibility to read directly to a device? - fn from_reader(shape: Shape, dtype: DType, reader: &mut R) -> Result { + pub(crate) fn from_reader( + shape: Shape, + dtype: DType, + reader: &mut R, + ) -> Result { let elem_count = shape.elem_count(); match dtype { DType::BF16 => { diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 8a69c1e9..f14a5046 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -1,7 +1,7 @@ // Just enough pickle support to be able to read PyTorch checkpoints. // This hardcodes objects that are required for tensor reading, we may want to make this a bit more // composable/tensor agnostic at some point. -use crate::{DType, Error as E, Layout, Result}; +use crate::{DType, Error as E, Layout, Result, Tensor}; use byteorder::{LittleEndian, ReadBytesExt}; use std::collections::HashMap; use std::io::BufRead; @@ -518,7 +518,7 @@ pub struct TensorInfo { pub name: String, pub dtype: DType, pub layout: Layout, - pub path: std::path::PathBuf, + pub path: String, } pub fn read_pth_tensor_info>(file: P) -> Result> { @@ -583,7 +583,7 @@ pub fn read_pth_tensor_info>(file: P) -> Result { @@ -595,3 +595,53 @@ pub fn read_pth_tensor_info>(file: P) -> Result, + path: std::path::PathBuf, + // We do not store a zip reader as it needs mutable access to extract data. Instead we + // re-create a zip reader for each tensor. +} + +impl PthTensors { + pub fn new>(path: P) -> Result { + let tensor_infos = read_pth_tensor_info(path.as_ref())?; + let tensor_infos = tensor_infos + .into_iter() + .map(|ti| (ti.name.to_string(), ti)) + .collect(); + let path = path.as_ref().to_owned(); + Ok(Self { tensor_infos, path }) + } + + pub fn tensor_infos(&self) -> &HashMap { + &self.tensor_infos + } + + pub fn get(&self, name: &str) -> Result> { + let tensor_info = match self.tensor_infos.get(name) { + None => return Ok(None), + Some(tensor_info) => tensor_info, + }; + // We hope that the file has not changed since first reading it. + let zip_reader = std::io::BufReader::new(std::fs::File::open(&self.path)?); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let mut reader = zip.by_name(&tensor_info.path)?; + + // Reading the data is a bit tricky as it can be strided, use an offset, etc. + // For now only support the basic case. + if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() { + crate::bail!( + "cannot retrieve non-contiguous tensors {:?}", + tensor_info.layout + ) + } + let tensor = Tensor::from_reader( + tensor_info.layout.shape().clone(), + tensor_info.dtype, + &mut reader, + )?; + Ok(Some(tensor)) + } +}