From 2cde0cb74be3ada6b2d845a018d4746929318afb Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Thu, 24 Aug 2023 18:45:10 +0100 Subject: [PATCH] More pickle support. (#588) * More pickle support. * Be more verbose. --- candle-core/examples/tensor-tools.rs | 2 +- candle-core/src/pickle.rs | 42 ++++++++++++++++++++++++++-- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs index d5f7dd57..67e6aa1e 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/candle-core/examples/tensor-tools.rs @@ -91,7 +91,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option, verbose: bool) -> R } } Format::Pth => { - let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?; + let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { println!( diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 8562b2a9..37c15018 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -41,6 +41,10 @@ pub enum OpCode { Build = b'b', Stop = b'.', NewObj = 0x81, + EmptyList = b']', + BinFloat = b'g', + Append = b'a', + Appends = b'e', } // Avoid using FromPrimitive so as not to drag another dependency. @@ -76,6 +80,10 @@ impl TryFrom for OpCode { b'b' => Ok(Self::Build), b'.' => Ok(Self::Stop), 0x81 => Ok(Self::NewObj), + b']' => Ok(Self::EmptyList), + b'G' => Ok(Self::BinFloat), + b'a' => Ok(Self::Append), + b'e' => Ok(Self::Appends), value => Err(value), } } @@ -98,10 +106,12 @@ pub enum Object { class_name: String, }, Int(i32), + Float(f64), Unicode(String), Bool(bool), None, Tuple(Vec), + List(Vec), Mark, Dict(Vec<(Object, Object)>), Reduce { @@ -400,6 +410,10 @@ impl Stack { let arg = r.read_i32::()?; self.push(Object::Int(arg)) } + OpCode::BinFloat => { + let arg = r.read_f64::()?; + self.push(Object::Float(arg)) + } OpCode::BinUnicode => { let len = r.read_u32::()?; let mut data = vec![0u8; len as usize]; @@ -433,6 +447,24 @@ impl Stack { } OpCode::NewTrue => self.push(Object::Bool(true)), OpCode::NewFalse => self.push(Object::Bool(false)), + OpCode::Append => { + let value = self.pop()?; + let pylist = self.last()?; + if let Object::List(d) = pylist { + d.push(value) + } else { + crate::bail!("expected a list, got {pylist:?}") + } + } + OpCode::Appends => { + let objs = self.pop_to_marker()?; + let pylist = self.last()?; + if let Object::List(d) = pylist { + d.extend(objs) + } else { + crate::bail!("expected a list, got {pylist:?}") + } + } OpCode::SetItem => { let value = self.pop()?; let key = self.pop()?; @@ -479,6 +511,7 @@ impl Stack { OpCode::Mark => self.push(Object::Mark), OpCode::Reduce => self.reduce()?, OpCode::EmptyTuple => self.push(Object::Tuple(vec![])), + OpCode::EmptyList => self.push(Object::List(vec![])), OpCode::BinGet => { let arg = r.read_u8()?; let obj = self.memo_get(arg as u32)?; @@ -549,7 +582,10 @@ pub struct TensorInfo { pub storage_size: usize, } -pub fn read_pth_tensor_info>(file: P) -> Result> { +pub fn read_pth_tensor_info>( + file: P, + verbose: bool, +) -> Result> { let file = std::fs::File::open(file)?; let zip_reader = std::io::BufReader::new(file); let mut zip = zip::ZipArchive::new(zip_reader)?; @@ -569,7 +605,7 @@ pub fn read_pth_tensor_info>(file: P) -> Result>(path: P) -> Result { - let tensor_infos = read_pth_tensor_info(path.as_ref())?; + let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?; let tensor_infos = tensor_infos .into_iter() .map(|ti| (ti.name.to_string(), ti))