Pickle decoder fix and Long1 opcode addition. (#2824)

* Pickle decoder changes: added Long1 opcode, fixed tensor offset calculation

* Apply rustfmt.

---------

Co-authored-by: Laurent <laurent.mazare@gmail.com>
This commit is contained in:
Christian Balcom
2025-03-23 03:10:08 -04:00
committed by GitHub
parent 0b24f7f0a4
commit 67b85f79f1

View File

@ -45,6 +45,7 @@ pub enum OpCode {
BinFloat = b'G', BinFloat = b'G',
Append = b'a', Append = b'a',
Appends = b'e', Appends = b'e',
Long1 = 0x8a,
} }
// Avoid using FromPrimitive so as not to drag another dependency. // Avoid using FromPrimitive so as not to drag another dependency.
@ -84,6 +85,7 @@ impl TryFrom<u8> for OpCode {
b'G' => Ok(Self::BinFloat), b'G' => Ok(Self::BinFloat),
b'a' => Ok(Self::Append), b'a' => Ok(Self::Append),
b'e' => Ok(Self::Appends), b'e' => Ok(Self::Appends),
0x8a => Ok(Self::Long1),
value => Err(value), value => Err(value),
} }
} }
@ -106,6 +108,7 @@ pub enum Object {
class_name: String, class_name: String,
}, },
Int(i32), Int(i32),
Long(i64),
Float(f64), Float(f64),
Unicode(String), Unicode(String),
Bool(bool), Bool(bool),
@ -170,6 +173,14 @@ impl Object {
} }
} }
pub fn int_or_long(self) -> OResult<i64> {
match self {
Self::Int(t) => Ok(t as i64),
Self::Long(t) => Ok(t),
_ => Err(self),
}
}
pub fn tuple(self) -> OResult<Vec<Self>> { pub fn tuple(self) -> OResult<Vec<Self>> {
match self { match self {
Self::Tuple(t) => Ok(t), Self::Tuple(t) => Ok(t),
@ -590,6 +601,15 @@ impl Stack {
let obj = self.new_obj(class, args)?; let obj = self.new_obj(class, args)?;
self.push(obj) self.push(obj)
} }
OpCode::Long1 => {
let n_bytes = r.read_u8()?;
let mut v = 0;
// Decode the next n bytes in little endian
for i in 0..n_bytes {
v |= (r.read_u8()? as i64) << (i * 8);
}
self.push(Object::Long(v))
}
} }
Ok(false) Ok(false)
} }
@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
let mut args = args.tuple()?; let mut args = args.tuple()?;
let stride = Vec::<usize>::try_from(args.remove(3))?; let stride = Vec::<usize>::try_from(args.remove(3))?;
let size = Vec::<usize>::try_from(args.remove(2))?; let size = Vec::<usize>::try_from(args.remove(2))?;
let offset = args.remove(1).int()? as usize; let offset = args.remove(1).int_or_long()? as usize;
let storage = args.remove(0).persistent_load()?; let storage = args.remove(0).persistent_load()?;
let mut storage = storage.tuple()?; let mut storage = storage.tuple()?;
let storage_size = storage.remove(4).int()? as usize; let storage_size = storage.remove(4).int_or_long()? as usize;
let path = storage.remove(2).unicode()?; let path = storage.remove(2).unicode()?;
let (_module_name, class_name) = storage.remove(1).class()?; let (_module_name, class_name) = storage.remove(1).class()?;
let dtype = match class_name.as_str() { let dtype = match class_name.as_str() {
@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> {
crate::bail!("unsupported storage type {other}") crate::bail!("unsupported storage type {other}")
} }
}; };
let layout = Layout::new(crate::Shape::from(size), stride, offset); let layout = Layout::new(
crate::Shape::from(size),
stride,
offset * dtype.size_in_bytes(),
);
Ok((layout, dtype, path, storage_size)) Ok((layout, dtype, path, storage_size))
} }