mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Small tweaks to the pickle handling to be able to use libtorch files. (#530)
* Small tweaks to the pickle handling to be able to use libtorch files. * Move the pytorch specific bits in a different function.
This commit is contained in:
@ -40,6 +40,7 @@ pub enum OpCode {
|
|||||||
Dict = b'd',
|
Dict = b'd',
|
||||||
Build = b'b',
|
Build = b'b',
|
||||||
Stop = b'.',
|
Stop = b'.',
|
||||||
|
NewObj = 0x81,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avoid using FromPrimitive so as not to drag another dependency.
|
// Avoid using FromPrimitive so as not to drag another dependency.
|
||||||
@ -74,6 +75,7 @@ impl TryFrom<u8> for OpCode {
|
|||||||
b'd' => Ok(Self::EmptyDict),
|
b'd' => Ok(Self::EmptyDict),
|
||||||
b'b' => Ok(Self::Build),
|
b'b' => Ok(Self::Build),
|
||||||
b'.' => Ok(Self::Stop),
|
b'.' => Ok(Self::Stop),
|
||||||
|
0x81 => Ok(Self::NewObj),
|
||||||
value => Err(value),
|
value => Err(value),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -106,6 +108,10 @@ pub enum Object {
|
|||||||
callable: Box<Object>,
|
callable: Box<Object>,
|
||||||
args: Box<Object>,
|
args: Box<Object>,
|
||||||
},
|
},
|
||||||
|
Build {
|
||||||
|
callable: Box<Object>,
|
||||||
|
args: Box<Object>,
|
||||||
|
},
|
||||||
PersistentLoad(Box<Object>),
|
PersistentLoad(Box<Object>),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -260,12 +266,19 @@ impl Stack {
|
|||||||
|
|
||||||
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
|
// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
|
||||||
fn build(&mut self) -> Result<()> {
|
fn build(&mut self) -> Result<()> {
|
||||||
let mut args = self.pop()?;
|
let args = self.pop()?;
|
||||||
let obj = self.last()?;
|
let obj = self.pop()?;
|
||||||
match (obj, &mut args) {
|
let obj = match (obj, args) {
|
||||||
(Object::Dict(obj), Object::Dict(args)) => obj.append(args),
|
(Object::Dict(mut obj), Object::Dict(mut args)) => {
|
||||||
(obj, args) => println!("build {obj:?} {args:?}"),
|
obj.append(&mut args);
|
||||||
|
Object::Dict(obj)
|
||||||
}
|
}
|
||||||
|
(obj, args) => Object::Build {
|
||||||
|
callable: Box::new(obj),
|
||||||
|
args: Box::new(args),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
self.push(obj);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -322,6 +335,13 @@ impl Stack {
|
|||||||
Ok(Object::PersistentLoad(Box::new(id)))
|
Ok(Object::PersistentLoad(Box::new(id)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn new_obj(&self, class: Object, args: Object) -> Result<Object> {
|
||||||
|
Ok(Object::Reduce {
|
||||||
|
callable: Box::new(class),
|
||||||
|
args: Box::new(args),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
|
fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
|
||||||
let mut mark_idx = None;
|
let mut mark_idx = None;
|
||||||
for (idx, obj) in self.stack.iter().enumerate().rev() {
|
for (idx, obj) in self.stack.iter().enumerate().rev() {
|
||||||
@ -477,6 +497,12 @@ impl Stack {
|
|||||||
let arg = r.read_u32::<LittleEndian>()?;
|
let arg = r.read_u32::<LittleEndian>()?;
|
||||||
self.memo_put(arg)?
|
self.memo_put(arg)?
|
||||||
}
|
}
|
||||||
|
OpCode::NewObj => {
|
||||||
|
let args = self.pop()?;
|
||||||
|
let class = self.pop()?;
|
||||||
|
let obj = self.new_obj(class, args)?;
|
||||||
|
self.push(obj)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
@ -546,6 +572,19 @@ pub fn read_pth_tensor_info<P: AsRef<std::path::Path>>(file: P) -> Result<Vec<Te
|
|||||||
if VERBOSE {
|
if VERBOSE {
|
||||||
println!("{obj:?}");
|
println!("{obj:?}");
|
||||||
}
|
}
|
||||||
|
let obj = match obj {
|
||||||
|
Object::Build { callable, args } => match *callable {
|
||||||
|
Object::Reduce { callable, args: _ } => match *callable {
|
||||||
|
Object::Class {
|
||||||
|
module_name,
|
||||||
|
class_name,
|
||||||
|
} if module_name == "__torch__" && class_name == "Module" => *args,
|
||||||
|
_ => continue,
|
||||||
|
},
|
||||||
|
_ => continue,
|
||||||
|
},
|
||||||
|
obj => obj,
|
||||||
|
};
|
||||||
if let Object::Dict(key_values) = obj {
|
if let Object::Dict(key_values) = obj {
|
||||||
for (name, value) in key_values.into_iter() {
|
for (name, value) in key_values.into_iter() {
|
||||||
let name = match name.unicode() {
|
let name = match name.unicode() {
|
||||||
|
Reference in New Issue
Block a user