mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Small tweaks to the pickle handling to be able to use libtorch files. (#530)
* Small tweaks to the pickle handling to be able to use libtorch files. * Move the pytorch specific bits in a different function.
This commit is contained in:
@ -40,6 +40,7 @@ pub enum OpCode {
|
||||
Dict = b'd',
|
||||
Build = b'b',
|
||||
Stop = b'.',
|
||||
NewObj = 0x81,
|
||||
}
|
||||
|
||||
// Avoid using FromPrimitive so as not to drag another dependency.
|
||||
@ -74,6 +75,7 @@ impl TryFrom<u8> for OpCode {
|
||||
b'd' => Ok(Self::EmptyDict),
|
||||
b'b' => Ok(Self::Build),
|
||||
b'.' => Ok(Self::Stop),
|
||||
0x81 => Ok(Self::NewObj),
|
||||
value => Err(value),
|
||||
}
|
||||
}
|
||||
@ -106,6 +108,10 @@ pub enum Object {
|
||||
callable: Box<Object>,
|
||||
args: Box<Object>,
|
||||
},
|
||||
Build {
|
||||
callable: Box<Object>,
|
||||
args: Box<Object>,
|
||||
},
|
||||
PersistentLoad(Box<Object>),
|
||||
}
|
||||
|
||||
@ -260,12 +266,19 @@ impl Stack {
|
||||
|
||||
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
|
||||
fn build(&mut self) -> Result<()> {
|
||||
let mut args = self.pop()?;
|
||||
let obj = self.last()?;
|
||||
match (obj, &mut args) {
|
||||
(Object::Dict(obj), Object::Dict(args)) => obj.append(args),
|
||||
(obj, args) => println!("build {obj:?} {args:?}"),
|
||||
}
|
||||
let args = self.pop()?;
|
||||
let obj = self.pop()?;
|
||||
let obj = match (obj, args) {
|
||||
(Object::Dict(mut obj), Object::Dict(mut args)) => {
|
||||
obj.append(&mut args);
|
||||
Object::Dict(obj)
|
||||
}
|
||||
(obj, args) => Object::Build {
|
||||
callable: Box::new(obj),
|
||||
args: Box::new(args),
|
||||
},
|
||||
};
|
||||
self.push(obj);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -322,6 +335,13 @@ impl Stack {
|
||||
Ok(Object::PersistentLoad(Box::new(id)))
|
||||
}
|
||||
|
||||
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
|
||||
Ok(Object::Reduce {
|
||||
callable: Box::new(class),
|
||||
args: Box::new(args),
|
||||
})
|
||||
}
|
||||
|
||||
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
|
||||
let mut mark_idx = None;
|
||||
for (idx, obj) in self.stack.iter().enumerate().rev() {
|
||||
@ -477,6 +497,12 @@ impl Stack {
|
||||
let arg = r.read_u32::<LittleEndian>()?;
|
||||
self.memo_put(arg)?
|
||||
}
|
||||
OpCode::NewObj => {
|
||||
let args = self.pop()?;
|
||||
let class = self.pop()?;
|
||||
let obj = self.new_obj(class, args)?;
|
||||
self.push(obj)
|
||||
}
|
||||
}
|
||||
Ok(false)
|
||||
}
|
||||
@ -546,6 +572,19 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
|
||||
if VERBOSE {
|
||||
println!("{obj:?}");
|
||||
}
|
||||
let obj = match obj {
|
||||
Object::Build { callable, args } => match *callable {
|
||||
Object::Reduce { callable, args: _ } => match *callable {
|
||||
Object::Class {
|
||||
module_name,
|
||||
class_name,
|
||||
} if module_name == "__torch__" && class_name == "Module" => *args,
|
||||
_ => continue,
|
||||
},
|
||||
_ => continue,
|
||||
},
|
||||
obj => obj,
|
||||
};
|
||||
if let Object::Dict(key_values) = obj {
|
||||
for (name, value) in key_values.into_iter() {
|
||||
let name = match name.unicode() {
|
||||
|
Reference in New Issue
Block a user