diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index e913935c..8562b2a9 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -40,6 +40,7 @@ pub enum OpCode { Dict = b'd', Build = b'b', Stop = b'.', + NewObj = 0x81, } // Avoid using FromPrimitive so as not to drag another dependency. @@ -74,6 +75,7 @@ impl TryFrom for OpCode { b'd' => Ok(Self::EmptyDict), b'b' => Ok(Self::Build), b'.' => Ok(Self::Stop), + 0x81 => Ok(Self::NewObj), value => Err(value), } } @@ -106,6 +108,10 @@ pub enum Object { callable: Box, args: Box, }, + Build { + callable: Box, + args: Box, + }, PersistentLoad(Box), } @@ -260,12 +266,19 @@ impl Stack { // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD fn build(&mut self) -> Result<()> { - let mut args = self.pop()?; - let obj = self.last()?; - match (obj, &mut args) { - (Object::Dict(obj), Object::Dict(args)) => obj.append(args), - (obj, args) => println!("build {obj:?} {args:?}"), - } + let args = self.pop()?; + let obj = self.pop()?; + let obj = match (obj, args) { + (Object::Dict(mut obj), Object::Dict(mut args)) => { + obj.append(&mut args); + Object::Dict(obj) + } + (obj, args) => Object::Build { + callable: Box::new(obj), + args: Box::new(args), + }, + }; + self.push(obj); Ok(()) } @@ -322,6 +335,13 @@ impl Stack { Ok(Object::PersistentLoad(Box::new(id))) } + fn new_obj(&self, class: Object, args: Object) -> Result { + Ok(Object::Reduce { + callable: Box::new(class), + args: Box::new(args), + }) + } + fn pop_to_marker(&mut self) -> Result> { let mut mark_idx = None; for (idx, obj) in self.stack.iter().enumerate().rev() { @@ -477,6 +497,12 @@ impl Stack { let arg = r.read_u32::()?; self.memo_put(arg)? } + OpCode::NewObj => { + let args = self.pop()?; + let class = self.pop()?; + let obj = self.new_obj(class, args)?; + self.push(obj) + } } Ok(false) } @@ -546,6 +572,19 @@ pub fn read_pth_tensor_info>(file: P) -> Result match *callable { + Object::Reduce { callable, args: _ } => match *callable { + Object::Class { + module_name, + class_name, + } if module_name == "__torch__" && class_name == "Module" => *args, + _ => continue, + }, + _ => continue, + }, + obj => obj, + }; if let Object::Dict(key_values) = obj { for (name, value) in key_values.into_iter() { let name = match name.unicode() {