From ad33715c61c0e5d7a90889618381df6b2a306799 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Sat, 19 Aug 2023 11:26:32 +0100
Subject: [PATCH] Preliminary support for importing PyTorch weights. (#511)

* Pickle work-in-progress.

* More unpickling.

* More pickling.

* Proper handling of setitems.

* Clippy.

* Again more pickling.

* Restore the example.

* Add enough pickle support to get the list of tensors.

* Read the data from zip files.

* Retrieve the tensor shape.

* Extract the size and dtype.

* More storage types.

* Improve the destructuring.
---
 candle-core/examples/tensor-tools.rs |  16 +
 candle-core/src/lib.rs               |   1 +
 candle-core/src/pickle.rs            | 556 +++++++++++++++++++++++++++
 3 files changed, 573 insertions(+)
 create mode 100644 candle-core/src/pickle.rs

diff --git a/candle-core/examples/tensor-tools.rs b/candle-core/examples/tensor-tools.rs
index 03ea923b..7baee582 100644
--- a/candle-core/examples/tensor-tools.rs
+++ b/candle-core/examples/tensor-tools.rs
@@ -45,6 +45,22 @@ fn run_ls(file: &std::path::PathBuf) -> Result<()> {
                 println!("{name}: [{shape:?}; {dtype}]")
             }
         }
+        Some("pt") | Some("pth") => {
+            let mut tensors = candle_core::pickle::read_pth_tensor_info(file)?;
+            tensors.sort_by(|a, b| a.0.cmp(&b.0));
+            for (name, dtype, shape) in tensors.iter() {
+                println!("{name}: [{shape:?}; {dtype:?}]")
+            }
+        }
+        Some("pkl") => {
+            let file = std::fs::File::open(file)?;
+            let mut reader = std::io::BufReader::new(file);
+            let mut stack = candle_core::pickle::Stack::empty();
+            stack.read_loop(&mut reader)?;
+            for (i, obj) in stack.stack().iter().enumerate() {
+                println!("{i} {obj:?}");
+            }
+        }
         Some(_) => {
             println!("{file:?}: unsupported file extension")
         }
diff --git a/candle-core/src/lib.rs b/candle-core/src/lib.rs
index 62ad55d1..3622d22e 100644
--- a/candle-core/src/lib.rs
+++ b/candle-core/src/lib.rs
@@ -56,6 +56,7 @@ pub mod layout;
 mod mkl;
 pub mod npy;
 mod op;
+pub mod pickle;
 pub mod quantized;
 pub mod safetensors;
 pub mod shape;
diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs
new file mode 100644
index 00000000..059a0d9c
--- /dev/null
+++ b/candle-core/src/pickle.rs
@@ -0,0 +1,556 @@
+// Just enough pickle support to be able to read PyTorch checkpoints.
+// This hardcodes objects that are required for tensor reading, we may want to make this a bit more
+// composable/tensor agnostic at some point.
+use crate::{DType, Error as E, Result};
+use byteorder::{LittleEndian, ReadBytesExt};
+use std::collections::HashMap;
+use std::io::BufRead;
+
+// https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/
+#[repr(u8)]
+#[derive(Debug, Eq, PartialEq, Clone)]
+pub enum OpCode {
+    // https://github.com/python/cpython/blob/ed25f097160b5cbb0c9a1f9a746d2f1bbc96515a/Lib/pickletools.py#L2123
+    Proto = 0x80,
+    Global = b'c',
+    BinPut = b'q',
+    LongBinPut = b'r',
+    EmptyTuple = b')',
+    Reduce = b'R',
+    Mark = b'(',
+    BinUnicode = b'X',
+    BinInt = b'J',
+    Tuple = b't',
+    BinPersId = b'Q',
+    BinInt1 = b'K',
+    BinInt2 = b'M',
+    Tuple1 = 0x85,
+    Tuple2 = 0x86,
+    Tuple3 = 0x87,
+    NewTrue = 0x88,
+    NewFalse = 0x89,
+    None = b'N',
+    BinGet = b'h',
+    LongBinGet = b'j',
+    SetItem = b's',
+    SetItems = b'u',
+    EmptyDict = b'}',
+    Dict = b'd',
+    Build = b'b',
+    Stop = b'.',
+}
+
+// Avoid using FromPrimitive so as not to drag another dependency.
+impl TryFrom<u8> for OpCode {
+    type Error = u8;
+    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
+        match value {
+            0x80 => Ok(Self::Proto),
+            b'c' => Ok(Self::Global),
+            b'q' => Ok(Self::BinPut),
+            b'r' => Ok(Self::LongBinPut),
+            b')' => Ok(Self::EmptyTuple),
+            b'R' => Ok(Self::Reduce),
+            b'(' => Ok(Self::Mark),
+            b'X' => Ok(Self::BinUnicode),
+            b'J' => Ok(Self::BinInt),
+            b't' => Ok(Self::Tuple),
+            b'Q' => Ok(Self::BinPersId),
+            b'K' => Ok(Self::BinInt1),
+            b'M' => Ok(Self::BinInt2),
+            b'N' => Ok(Self::None),
+            0x85 => Ok(Self::Tuple1),
+            0x86 => Ok(Self::Tuple2),
+            0x87 => Ok(Self::Tuple3),
+            0x88 => Ok(Self::NewTrue),
+            0x89 => Ok(Self::NewFalse),
+            b'h' => Ok(Self::BinGet),
+            b'j' => Ok(Self::LongBinGet),
+            b's' => Ok(Self::SetItem),
+            b'u' => Ok(Self::SetItems),
+            b'}' => Ok(Self::EmptyDict),
+            b'd' => Ok(Self::Dict),
+            b'b' => Ok(Self::Build),
+            b'.' => Ok(Self::Stop),
+            value => Err(value),
+        }
+    }
+}
+
+fn read_to_newline<R: BufRead>(r: &mut R) -> Result<Vec<u8>> {
+    let mut data: Vec<u8> = Vec::with_capacity(32);
+    r.read_until(b'\n', &mut data)?;
+    data.pop();
+    if data.last() == Some(&b'\r') {
+        data.pop();
+    }
+    Ok(data)
+}
+
+#[derive(Debug, Clone, PartialEq)]
+pub enum Object {
+    Class {
+        module_name: String,
+        class_name: String,
+    },
+    Int(i32),
+    Unicode(String),
+    Bool(bool),
+    None,
+    Tuple(Vec<Object>),
+    Mark,
+    Dict(Vec<(Object, Object)>),
+    Reduce {
+        callable: Box<Object>,
+        args: Box<Object>,
+    },
+    PersistentLoad(Box<Object>),
+}
+
+type OResult<T> = std::result::Result<T, Object>;
+
+impl Object {
+    pub fn unicode(self) -> OResult<String> {
+        match self {
+            Self::Unicode(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn reduce(self) -> OResult<(Self, Self)> {
+        match self {
+            Self::Reduce { callable, args } => Ok((*callable, *args)),
+            _ => Err(self),
+        }
+    }
+
+    pub fn none(self) -> OResult<()> {
+        match self {
+            Self::None => Ok(()),
+            _ => Err(self),
+        }
+    }
+
+    pub fn persistent_load(self) -> OResult<Self> {
+        match self {
+            Self::PersistentLoad(t) => Ok(*t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn bool(self) -> OResult<bool> {
+        match self {
+            Self::Bool(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn int(self) -> OResult<i32> {
+        match self {
+            Self::Int(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn tuple(self) -> OResult<Vec<Self>> {
+        match self {
+            Self::Tuple(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn dict(self) -> OResult<Vec<(Self, Self)>> {
+        match self {
+            Self::Dict(t) => Ok(t),
+            _ => Err(self),
+        }
+    }
+
+    pub fn class(self) -> OResult<(String, String)> {
+        match self {
+            Self::Class {
+                module_name,
+                class_name,
+            } => Ok((module_name, class_name)),
+            _ => Err(self),
+        }
+    }
+}
+
+impl TryFrom<Object> for String {
+    type Error = Object;
+    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
+        match value {
+            Object::Unicode(s) => Ok(s),
+            other => Err(other),
+        }
+    }
+}
+
+impl TryFrom<Object> for usize {
+    type Error = Object;
+    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
+        match value {
+            Object::Int(s) if s >= 0 => Ok(s as usize),
+            other => Err(other),
+        }
+    }
+}
+
+impl<T: TryFrom<Object, Error = Object>> TryFrom<Object> for Vec<T> {
+    type Error = Object;
+    fn try_from(value: Object) -> std::result::Result<Self, Self::Error> {
+        match value {
+            Object::Tuple(values) => {
+                // This does not return the appropriate value in the error case but instead returns
+                // the object related to the first error.
+                values
+                    .into_iter()
+                    .map(|v| T::try_from(v))
+                    .collect::<std::result::Result<Vec<T>, Self::Error>>()
+            }
+            other => Err(other),
+        }
+    }
+}
+
+#[derive(Debug)]
+pub struct Stack {
+    stack: Vec<Object>,
+    memo: HashMap<u32, Object>,
+}
+
+impl Stack {
+    pub fn empty() -> Self {
+        Self {
+            stack: Vec::with_capacity(512),
+            memo: HashMap::new(),
+        }
+    }
+
+    pub fn stack(&self) -> &[Object] {
+        self.stack.as_slice()
+    }
+
+    pub fn read_loop<R: BufRead>(&mut self, r: &mut R) -> Result<()> {
+        loop {
+            if self.read(r)? {
+                break;
+            }
+        }
+        Ok(())
+    }
+
+    pub fn finalize(mut self) -> Result<Object> {
+        self.pop()
+    }
+
+    fn push(&mut self, obj: Object) {
+        self.stack.push(obj)
+    }
+
+    fn pop(&mut self) -> Result<Object> {
+        match self.stack.pop() {
+            None => crate::bail!("unexpected empty stack"),
+            Some(obj) => Ok(obj),
+        }
+    }
+
+    // https://docs.juliahub.com/Pickle/LAUNc/0.1.0/opcode/#Pickle.OpCodes.BUILD
+    fn build(&mut self) -> Result<()> {
+        let mut args = self.pop()?;
+        let obj = self.last()?;
+        match (obj, &mut args) {
+            (Object::Dict(obj), Object::Dict(args)) => obj.append(args),
+            (obj, args) => println!("build {obj:?} {args:?}"),
+        }
+        Ok(())
+    }
+
+    fn reduce(&mut self) -> Result<()> {
+        let args = self.pop()?;
+        let callable = self.pop()?;
+        #[allow(clippy::single_match)]
+        let reduced = match &callable {
+            Object::Class {
+                module_name,
+                class_name,
+            } => {
+                if module_name == "collections" && class_name == "OrderedDict" {
+                    // TODO: have a separate ordered dict.
+                    Some(Object::Dict(vec![]))
+                } else {
+                    None
+                }
+            }
+            _ => None,
+        };
+        let reduced = reduced.unwrap_or_else(|| Object::Reduce {
+            callable: Box::new(callable),
+            args: Box::new(args),
+        });
+        self.push(reduced);
+        Ok(())
+    }
+
+    fn last(&mut self) -> Result<&mut Object> {
+        match self.stack.last_mut() {
+            None => crate::bail!("unexpected empty stack"),
+            Some(obj) => Ok(obj),
+        }
+    }
+
+    fn memo_get(&self, id: u32) -> Result<Object> {
+        match self.memo.get(&id) {
+            None => crate::bail!("missing object in memo {id}"),
+            Some(obj) => {
+                // Maybe we should use refcounting rather than doing potential large clones here.
+                Ok(obj.clone())
+            }
+        }
+    }
+
+    fn memo_put(&mut self, id: u32) -> Result<()> {
+        let obj = self.last()?.clone();
+        self.memo.insert(id, obj);
+        Ok(())
+    }
+
+    fn persistent_load(&self, id: Object) -> Result<Object> {
+        Ok(Object::PersistentLoad(Box::new(id)))
+    }
+
+    fn pop_to_marker(&mut self) -> Result<Vec<Object>> {
+        let mut mark_idx = None;
+        for (idx, obj) in self.stack.iter().enumerate().rev() {
+            if obj == &Object::Mark {
+                mark_idx = Some(idx);
+                break;
+            }
+        }
+        match mark_idx {
+            Some(mark_idx) => {
+                let objs = self.stack.split_off(mark_idx + 1);
+                self.stack.pop();
+                Ok(objs)
+            }
+            None => {
+                crate::bail!("marker object not found")
+            }
+        }
+    }
+
+    pub fn read<R: BufRead>(&mut self, r: &mut R) -> Result<bool> {
+        let op_code = match OpCode::try_from(r.read_u8()?) {
+            Ok(op_code) => op_code,
+            Err(op_code) => {
+                crate::bail!("unknown op-code {op_code}")
+            }
+        };
+        // println!("op: {op_code:?}");
+        // println!("{:?}", self.stack);
+        match op_code {
+            OpCode::Proto => {
+                let version = r.read_u8()?;
+                println!("proto {version}");
+            }
+            OpCode::Global => {
+                let module_name = read_to_newline(r)?;
+                let class_name = read_to_newline(r)?;
+                let module_name = String::from_utf8_lossy(&module_name).to_string();
+                let class_name = String::from_utf8_lossy(&class_name).to_string();
+                self.push(Object::Class {
+                    module_name,
+                    class_name,
+                })
+            }
+            OpCode::BinInt1 => {
+                let arg = r.read_u8()?;
+                self.push(Object::Int(arg as i32))
+            }
+            OpCode::BinInt2 => {
+                let arg = r.read_u16::<LittleEndian>()?;
+                self.push(Object::Int(arg as i32))
+            }
+            OpCode::BinInt => {
+                let arg = r.read_i32::<LittleEndian>()?;
+                self.push(Object::Int(arg))
+            }
+            OpCode::BinUnicode => {
+                let len = r.read_u32::<LittleEndian>()?;
+                let mut data = vec![0u8; len as usize];
+                r.read_exact(&mut data)?;
+                let data = String::from_utf8(data).map_err(E::wrap)?;
+                self.push(Object::Unicode(data))
+            }
+            OpCode::BinPersId => {
+                let id = self.pop()?;
+                let obj = self.persistent_load(id)?;
+                self.push(obj)
+            }
+            OpCode::Tuple => {
+                let objs = self.pop_to_marker()?;
+                self.push(Object::Tuple(objs))
+            }
+            OpCode::Tuple1 => {
+                let obj = self.pop()?;
+                self.push(Object::Tuple(vec![obj]))
+            }
+            OpCode::Tuple2 => {
+                let obj2 = self.pop()?;
+                let obj1 = self.pop()?;
+                self.push(Object::Tuple(vec![obj1, obj2]))
+            }
+            OpCode::Tuple3 => {
+                let obj3 = self.pop()?;
+                let obj2 = self.pop()?;
+                let obj1 = self.pop()?;
+                self.push(Object::Tuple(vec![obj1, obj2, obj3]))
+            }
+            OpCode::NewTrue => self.push(Object::Bool(true)),
+            OpCode::NewFalse => self.push(Object::Bool(false)),
+            OpCode::SetItem => {
+                let value = self.pop()?;
+                let key = self.pop()?;
+                let pydict = self.last()?;
+                if let Object::Dict(d) = pydict {
+                    d.push((key, value))
+                } else {
+                    crate::bail!("expected a dict, got {pydict:?}")
+                }
+            }
+            OpCode::SetItems => {
+                let mut objs = self.pop_to_marker()?;
+                let pydict = self.last()?;
+                if let Object::Dict(d) = pydict {
+                    if objs.len() % 2 != 0 {
+                        crate::bail!("setitems: not an even number of objects")
+                    }
+                    while let Some(value) = objs.pop() {
+                        let key = objs.pop().unwrap();
+                        d.push((key, value))
+                    }
+                } else {
+                    crate::bail!("expected a dict, got {pydict:?}")
+                }
+            }
+            OpCode::None => self.push(Object::None),
+            OpCode::Stop => {
+                return Ok(true);
+            }
+            OpCode::Build => self.build()?,
+            OpCode::EmptyDict => self.push(Object::Dict(vec![])),
+            OpCode::Dict => {
+                let mut objs = self.pop_to_marker()?;
+                let mut pydict = vec![];
+                if objs.len() % 2 != 0 {
+                    crate::bail!("setitems: not an even number of objects")
+                }
+                while let Some(value) = objs.pop() {
+                    let key = objs.pop().unwrap();
+                    pydict.push((key, value))
+                }
+                self.push(Object::Dict(pydict))
+            }
+            OpCode::Mark => self.push(Object::Mark),
+            OpCode::Reduce => self.reduce()?,
+            OpCode::EmptyTuple => self.push(Object::Tuple(vec![])),
+            OpCode::BinGet => {
+                let arg = r.read_u8()?;
+                let obj = self.memo_get(arg as u32)?;
+                self.push(obj)
+            }
+            OpCode::LongBinGet => {
+                let arg = r.read_u32::<LittleEndian>()?;
+                let obj = self.memo_get(arg)?;
+                self.push(obj)
+            }
+            OpCode::BinPut => {
+                let arg = r.read_u8()?;
+                self.memo_put(arg as u32)?
+            }
+            OpCode::LongBinPut => {
+                let arg = r.read_u32::<LittleEndian>()?;
+                self.memo_put(arg)?
+ } + } + Ok(false) + } +} + +impl From for E { + fn from(value: Object) -> Self { + E::Msg(format!("conversion error on {value:?}")) + } +} + +// https://github.com/pytorch/pytorch/blob/4eac43d046ded0f0a5a5fa8db03eb40f45bf656e/torch/_utils.py#L198 +// Arguments: storage, storage_offset, size, stride, requires_grad, backward_hooks +fn rebuild_args(args: Object) -> Result<(Vec, DType)> { + let mut args = args.tuple()?; + let size = Vec::::try_from(args.remove(2))?; + let storage = args.remove(0).persistent_load()?; + let mut storage = storage.tuple()?; + let (_module_name, class_name) = storage.remove(1).class()?; + let dtype = match class_name.as_str() { + "FloatStorage" => DType::F32, + "DoubleStorage" => DType::F64, + "HalfStorage" => DType::F16, + "BFloat16Storage" => DType::BF16, + "ByteStorage" => DType::U8, + other => { + crate::bail!("unsupported storage type {other}") + } + }; + Ok((size, dtype)) +} + +pub fn read_pth_tensor_info>( + file: P, +) -> Result)>> { + let file = std::fs::File::open(file)?; + let zip_reader = std::io::BufReader::new(file); + let mut zip = zip::ZipArchive::new(zip_reader)?; + let zip_file_names = zip + .file_names() + .map(|f| f.to_string()) + .collect::>(); + + let mut tensor_info = vec![]; + for name in zip_file_names.iter() { + if !name.ends_with("data.pkl") { + continue; + } + let reader = zip.by_name(name)?; + let mut reader = std::io::BufReader::new(reader); + let mut stack = Stack::empty(); + stack.read_loop(&mut reader)?; + let obj = stack.finalize()?; + if let Object::Dict(key_values) = obj { + for (key, value) in key_values.into_iter() { + let key = match key.unicode() { + Ok(key) => key, + Err(_) => continue, + }; + let (callable, args) = match value.reduce() { + Ok(callable_args) => callable_args, + _ => continue, + }; + match callable { + Object::Class { + module_name, + class_name, + } if module_name == "torch._utils" && class_name == "_rebuild_tensor_v2" => {} + _ => continue, + }; + match rebuild_args(args) { + Ok((size, dtype)) => tensor_info.push((key, dtype, size)), + Err(err) => { + eprintln!("skipping {key}: {err:?}") + } + } + } + } + } + Ok(tensor_info) +}