mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 02:38:10 +00:00
Pickle decoder fix and Long1 opcode addition. (#2824)
* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation * Apply rustfmt. --------- Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
@ -45,6 +45,7 @@ pub enum OpCode {
|
|||||||
BinFloat = b'G',
|
BinFloat = b'G',
|
||||||
Append = b'a',
|
Append = b'a',
|
||||||
Appends = b'e',
|
Appends = b'e',
|
||||||
|
Long1 = 0x8a,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Avoid using FromPrimitive so as not to drag another dependency.
|
// Avoid using FromPrimitive so as not to drag another dependency.
|
||||||
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
|
|||||||
b'G' => Ok(Self::BinFloat),
|
b'G' => Ok(Self::BinFloat),
|
||||||
b'a' => Ok(Self::Append),
|
b'a' => Ok(Self::Append),
|
||||||
b'e' => Ok(Self::Appends),
|
b'e' => Ok(Self::Appends),
|
||||||
|
0x8a => Ok(Self::Long1),
|
||||||
value => Err(value),
|
value => Err(value),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -106,6 +108,7 @@ pub enum Object {
|
|||||||
class_name: String,
|
class_name: String,
|
||||||
},
|
},
|
||||||
Int(i32),
|
Int(i32),
|
||||||
|
Long(i64),
|
||||||
Float(f64),
|
Float(f64),
|
||||||
Unicode(String),
|
Unicode(String),
|
||||||
Bool(bool),
|
Bool(bool),
|
||||||
@ -170,6 +173,14 @@ impl Object {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn int_or_long(self) -> OResult<i64> {
|
||||||
|
match self {
|
||||||
|
Self::Int(t) => Ok(t as i64),
|
||||||
|
Self::Long(t) => Ok(t),
|
||||||
|
_ => Err(self),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn tuple(self) -> OResult<Vec<Self>> {
|
pub fn tuple(self) -> OResult<Vec<Self>> {
|
||||||
match self {
|
match self {
|
||||||
Self::Tuple(t) => Ok(t),
|
Self::Tuple(t) => Ok(t),
|
||||||
@ -590,6 +601,15 @@ impl Stack {
|
|||||||
let obj = self.new_obj(class, args)?;
|
let obj = self.new_obj(class, args)?;
|
||||||
self.push(obj)
|
self.push(obj)
|
||||||
}
|
}
|
||||||
|
OpCode::Long1 => {
|
||||||
|
let n_bytes = r.read_u8()?;
|
||||||
|
let mut v = 0;
|
||||||
|
// Decode the next n bytes in little endian
|
||||||
|
for i in 0..n_bytes {
|
||||||
|
v |= (r.read_u8()? as i64) << (i * 8);
|
||||||
|
}
|
||||||
|
self.push(Object::Long(v))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
|||||||
let mut args = args.tuple()?;
|
let mut args = args.tuple()?;
|
||||||
let stride = Vec::<usize>::try_from(args.remove(3))?;
|
let stride = Vec::<usize>::try_from(args.remove(3))?;
|
||||||
let size = Vec::<usize>::try_from(args.remove(2))?;
|
let size = Vec::<usize>::try_from(args.remove(2))?;
|
||||||
let offset = args.remove(1).int()? as usize;
|
let offset = args.remove(1).int_or_long()? as usize;
|
||||||
let storage = args.remove(0).persistent_load()?;
|
let storage = args.remove(0).persistent_load()?;
|
||||||
let mut storage = storage.tuple()?;
|
let mut storage = storage.tuple()?;
|
||||||
let storage_size = storage.remove(4).int()? as usize;
|
let storage_size = storage.remove(4).int_or_long()? as usize;
|
||||||
let path = storage.remove(2).unicode()?;
|
let path = storage.remove(2).unicode()?;
|
||||||
let (_module_name, class_name) = storage.remove(1).class()?;
|
let (_module_name, class_name) = storage.remove(1).class()?;
|
||||||
let dtype = match class_name.as_str() {
|
let dtype = match class_name.as_str() {
|
||||||
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
|
|||||||
crate::bail!("unsupported storage type {other}")
|
crate::bail!("unsupported storage type {other}")
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
let layout = Layout::new(crate::Shape::from(size), stride, offset);
|
let layout = Layout::new(
|
||||||
|
crate::Shape::from(size),
|
||||||
|
stride,
|
||||||
|
offset * dtype.size_in_bytes(),
|
||||||
|
);
|
||||||
Ok((layout, dtype, path, storage_size))
|
Ok((layout, dtype, path, storage_size))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user