More pickle support. (#588)

* More pickle support.

* Be more verbose.
This commit is contained in:
Laurent Mazare
2023-08-24 18:45:10 +01:00
committed by GitHub
parent e21c686cdc
commit 2cde0cb74b
2 changed files with 40 additions and 4 deletions

View File

@ -91,7 +91,7 @@ fn run_ls(file: &std::path::PathBuf, format: Option<Format>, verbose: bool) -> R
} }
} }
Format::Pth => { Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?; let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name)); tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() { for tensor_info in tensors.iter() {
println!( println!(

View File

@ -41,6 +41,10 @@ pub enum OpCode {
Build = b'b', Build = b'b',
Stop = b'.', Stop = b'.',
NewObj = 0x81, NewObj = 0x81,
EmptyList = b']',
BinFloat = b'g',
Append = b'a',
Appends = b'e',
} }
// Avoid using FromPrimitive so as not to drag another dependency. // Avoid using FromPrimitive so as not to drag another dependency.
@ -76,6 +80,10 @@ impl TryFrom<u8> for OpCode {
b'b' => Ok(Self::Build), b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop), b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj), 0x81 => Ok(Self::NewObj),
b']' => Ok(Self::EmptyList),
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
value => Err(value), value => Err(value),
} }
} }
@ -98,10 +106,12 @@ pub enum Object {
class_name: String, class_name: String,
}, },
Int(i32), Int(i32),
Float(f64),
Unicode(String), Unicode(String),
Bool(bool), Bool(bool),
None, None,
Tuple(Vec<Object>), Tuple(Vec<Object>),
List(Vec<Object>),
Mark, Mark,
Dict(Vec<(Object, Object)>), Dict(Vec<(Object, Object)>),
Reduce { Reduce {
@ -400,6 +410,10 @@ impl Stack {
let arg = r.read_i32::<LittleEndian>()?; let arg = r.read_i32::<LittleEndian>()?;
self.push(Object::Int(arg)) self.push(Object::Int(arg))
} }
OpCode::BinFloat => {
let arg = r.read_f64::<LittleEndian>()?;
self.push(Object::Float(arg))
}
OpCode::BinUnicode => { OpCode::BinUnicode => {
let len = r.read_u32::<LittleEndian>()?; let len = r.read_u32::<LittleEndian>()?;
let mut data = vec![0u8; len as usize]; let mut data = vec![0u8; len as usize];
@ -433,6 +447,24 @@ impl Stack {
} }
OpCode::NewTrue => self.push(Object::Bool(true)), OpCode::NewTrue => self.push(Object::Bool(true)),
OpCode::NewFalse => self.push(Object::Bool(false)), OpCode::NewFalse => self.push(Object::Bool(false)),
OpCode::Append => {
let value = self.pop()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.push(value)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::Appends => {
let objs = self.pop_to_marker()?;
let pylist = self.last()?;
if let Object::List(d) = pylist {
d.extend(objs)
} else {
crate::bail!("expected a list, got {pylist:?}")
}
}
OpCode::SetItem => { OpCode::SetItem => {
let value = self.pop()?; let value = self.pop()?;
let key = self.pop()?; let key = self.pop()?;
@ -479,6 +511,7 @@ impl Stack {
OpCode::Mark => self.push(Object::Mark), OpCode::Mark => self.push(Object::Mark),
OpCode::Reduce => self.reduce()?, OpCode::Reduce => self.reduce()?,
OpCode::EmptyTuple => self.push(Object::Tuple(vec![])), OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
OpCode::EmptyList => self.push(Object::List(vec![])),
OpCode::BinGet => { OpCode::BinGet => {
let arg = r.read_u8()?; let arg = r.read_u8()?;
let obj = self.memo_get(arg as u32)?; let obj = self.memo_get(arg as u32)?;
@ -549,7 +582,10 @@ pub struct TensorInfo {
pub storage_size: usize, pub storage_size: usize,
} }
pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<TensorInfo>> { pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(
file: P,
verbose: bool,
) -> Result<Vec<TensorInfo>> {
let file = std::fs::File::open(file)?; let file = std::fs::File::open(file)?;
let zip_reader = std::io::BufReader::new(file); let zip_reader = std::io::BufReader::new(file);
let mut zip = zip::ZipArchive::new(zip_reader)?; let mut zip = zip::ZipArchive::new(zip_reader)?;
@ -569,7 +605,7 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
let mut stack = Stack::empty(); let mut stack = Stack::empty();
stack.read_loop(&mut reader)?; stack.read_loop(&mut reader)?;
let obj = stack.finalize()?; let obj = stack.finalize()?;
if VERBOSE { if VERBOSE || verbose {
println!("{obj:?}"); println!("{obj:?}");
} }
let obj = match obj { let obj = match obj {
@ -648,7 +684,7 @@ pub struct PthTensors {
impl PthTensors { impl PthTensors {
pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> { pub fn new<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
let tensor_infos = read_pth_tensor_info(path.as_ref())?; let tensor_infos = read_pth_tensor_info(path.as_ref(), false)?;
let tensor_infos = tensor_infos let tensor_infos = tensor_infos
.into_iter() .into_iter()
.map(|ti| (ti.name.to_string(), ti)) .map(|ti| (ti.name.to_string(), ti))