Pickle decoder fix and Long1 opcode addition. (#2824)

* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Christian Balcom
2025-03-23 03:10:08 -04:00
committed by GitHub
parent 0b24f7f0a4
commit 67b85f79f1

View File

@ -45,6 +45,7 @@ pub enum OpCode {
BinFloat = b'G',
Append = b'a',
Appends = b'e',
Long1 = 0x8a,
}
// Avoid using FromPrimitive so as not to drag another dependency.
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value),
}
}
@ -106,6 +108,7 @@ pub enum Object {
class_name: String,
},
Int(i32),
Long(i64),
Float(f64),
Unicode(String),
Bool(bool),
@ -170,6 +173,14 @@ impl Object {
}
}
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> {
match self {
Self::Tuple(t) => Ok(t),
@ -590,6 +601,15 @@ impl Stack {
let obj = self.new_obj(class, args)?;
self.push(obj)
}
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
}
Ok(false)
}
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize;
let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize;
let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() {
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
crate::bail!("unsupported storage type {other}")
}
};
let layout = Layout::new(crate::Shape::from(size), stride, offset);
let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size))
}