diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 1632cc26..8b13b50b 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -45,6 +45,7 @@ pub enum OpCode { BinFloat = b'G', Append = b'a', Appends = b'e', + Long1 = 0x8a, } // Avoid using FromPrimitive so as not to drag another dependency. @@ -84,6 +85,7 @@ impl TryFrom for OpCode { b'G' => Ok(Self::BinFloat), b'a' => Ok(Self::Append), b'e' => Ok(Self::Appends), + 0x8a => Ok(Self::Long1), value => Err(value), } } @@ -106,6 +108,7 @@ pub enum Object { class_name: String, }, Int(i32), + Long(i64), Float(f64), Unicode(String), Bool(bool), @@ -170,6 +173,14 @@ impl Object { } } + pub fn int_or_long(self) -> OResult { + match self { + Self::Int(t) => Ok(t as i64), + Self::Long(t) => Ok(t), + _ => Err(self), + } + } + pub fn tuple(self) -> OResult> { match self { Self::Tuple(t) => Ok(t), @@ -590,6 +601,15 @@ impl Stack { let obj = self.new_obj(class, args)?; self.push(obj) } + OpCode::Long1 => { + let n_bytes = r.read_u8()?; + let mut v = 0; + // Decode the next n bytes in little endian + for i in 0..n_bytes { + v |= (r.read_u8()? as i64) << (i * 8); + } + self.push(Object::Long(v)) + } } Ok(false) } @@ -607,10 +627,10 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> { let mut args = args.tuple()?; let stride = Vec::::try_from(args.remove(3))?; let size = Vec::::try_from(args.remove(2))?; - let offset = args.remove(1).int()? as usize; + let offset = args.remove(1).int_or_long()? as usize; let storage = args.remove(0).persistent_load()?; let mut storage = storage.tuple()?; - let storage_size = storage.remove(4).int()? as usize; + let storage_size = storage.remove(4).int_or_long()? as usize; let path = storage.remove(2).unicode()?; let (_module_name, class_name) = storage.remove(1).class()?; let dtype = match class_name.as_str() { @@ -624,7 +644,11 @@ fn rebuild_args(args: Object) -> Result<(Layout, DType, String, usize)> { crate::bail!("unsupported storage type {other}") } }; - let layout = Layout::new(crate::Shape::from(size), stride, offset); + let layout = Layout::new( + crate::Shape::from(size), + stride, + offset * dtype.size_in_bytes(), + ); Ok((layout, dtype, path, storage_size)) }