More pickle support. (#588)

* More pickle support.

* Be more verbose.
This commit is contained in:
Laurent Mazare
2023-08-24 18:45:10 +01:00
committed by GitHub
parent e21c686cdc
commit 2cde0cb74b
2 changed files with 40 additions and 4 deletions

View File

@ -91,7 +91,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(

View File

@ -41,6 +41,10 @@ pub enum OpCode {
Build = b'b',
Stop = b'.',
NewObj = 0x81,
EmptyList = b']',
BinFloat = b'g',
Append = b'a',
Appends = b'e',
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -76,6 +80,10 @@ impl TryFrom<u8> for OpCode {
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
value => Err(value),
}
}
@ -98,10 +106,12 @@ pub enum Object {
class_name: String,
},
Int(i32),
Float(f64),
Unicode(String),
Bool(bool),
None,
Tuple(Vec<Object>),
List(Vec<Object>),
Mark,
Dict(Vec<(Object, Object)>),
Reduce {
@ -400,6 +410,10 @@ impl Stack {
let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg))
}
OpCode::BinFloat => {
let arg = r.read_f64::<LittleEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize];
@ -433,6 +447,24 @@ impl Stack {
}
OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => {
let value = self.pop()?;
let key = self.pop()?;
@ -479,6 +511,7 @@ impl Stack {
OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => {
let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?;
@ -549,7 +582,10 @@ pub struct TensorInfo {
pub storage_size: usize,
}
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> {
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?;
@ -569,7 +605,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
let mut stack = Stack::empty();
stack.read_loop(&mut reader)?;
let obj = stack.finalize()?;
if VERBOSE {
if VERBOSE || verbose {
println!("{obj:?}");
}
let obj = match obj {
@ -648,7 +684,7 @@ pub struct PthTensors {
impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref())?;
let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos
.into_iter()
.map(|ti| (ti.name.to_string(), ti))