Small tweaks to the pickle handling to be able to use libtorch files. (#530)

* Small tweaks to the pickle handling to be able to use libtorch files.

* Move the pytorch specific bits in a different function.
This commit is contained in:
Laurent Mazare
2023-08-20 23:25:34 +01:00
committed by GitHub
parent 11c7e7bd67
commit 8c232d706b

View File

@ -40,6 +40,7 @@ pub enum OpCode {
Dict = b'd',
Build = b'b',
Stop = b'.',
NewObj = 0x81,
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -74,6 +75,7 @@ impl TryFrom<u8> for OpCode {
b'd' => Ok(Self::EmptyDict),
b'b' => Ok(Self::Build),
b'.' => Ok(Self::Stop),
0x81 => Ok(Self::NewObj),
value => Err(value),
}
}
@ -106,6 +108,10 @@ pub enum Object {
callable: Box<Object>,
args: Box<Object>,
},
Build {
callable: Box<Object>,
args: Box<Object>,
},
PersistentLoad(Box<Object>),
}
@ -260,12 +266,19 @@ impl Stack {
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
fn build(&mut self) -> Result<()> {
let mut args = self.pop()?;
let obj = self.last()?;
match (obj, &mut args) {
(Object::Dict(obj), Object::Dict(args)) => obj.append(args),
(obj, args) => println!("build {obj:?} {args:?}"),
}
let args = self.pop()?;
let obj = self.pop()?;
let obj = match (obj, args) {
(Object::Dict(mut obj), Object::Dict(mut args)) => {
obj.append(&mut args);
Object::Dict(obj)
}
(obj, args) => Object::Build {
callable: Box::new(obj),
args: Box::new(args),
},
};
self.push(obj);
Ok(())
}
@ -322,6 +335,13 @@ impl Stack {
Ok(Object::PersistentLoad(Box::new(id)))
}
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
Ok(Object::Reduce {
callable: Box::new(class),
args: Box::new(args),
})
}
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
let mut mark_idx = None;
for (idx, obj) in self.stack.iter().enumerate().rev() {
@ -477,6 +497,12 @@ impl Stack {
let arg = r.read_u32::<LittleEndian>()?;
self.memo_put(arg)?
}
OpCode::NewObj => {
let args = self.pop()?;
let class = self.pop()?;
let obj = self.new_obj(class, args)?;
self.push(obj)
}
}
Ok(false)
}
@ -546,6 +572,19 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
if VERBOSE {
println!("{obj:?}");
}
let obj = match obj {
Object::Build { callable, args } => match *callable {
Object::Reduce { callable, args: _ } => match *callable {
Object::Class {
module_name,
class_name,
} if module_name == "__torch__" && class_name == "Module" => *args,
_ => continue,
},
_ => continue,
},
obj => obj,
};
if let Object::Dict(key_values) = obj {
for (name, value) in key_values.into_iter() {
let name = match name.unicode() {