From 12b2a337f30f023af157b9ae560b53c3c5bd416c Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 8 Jan 2024 09:20:48 +0100 Subject: [PATCH] Handle start-offset when loading a tensor from a pickle file. (#1546) --- candle-core/src/pickle.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/candle-core/src/pickle.rs b/candle-core/src/pickle.rs index 25640d1a..276b30e3 100644 --- a/candle-core/src/pickle.rs +++ b/candle-core/src/pickle.rs @@ -703,6 +703,7 @@ impl PthTensors { } pub fn get(&self, name: &str) -> Result> { + use std::io::Read; let tensor_info = match self.tensor_infos.get(name) { None => return Ok(None), Some(tensor_info) => tensor_info, @@ -712,14 +713,21 @@ impl PthTensors { let mut zip = zip::ZipArchive::new(zip_reader)?; let mut reader = zip.by_name(&tensor_info.path)?; - // Reading the data is a bit tricky as it can be strided, use an offset, etc. - // For now only support the basic case. - if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() { + // Reading the data is a bit tricky as it can be strided, for now only support the basic + // case. + if !tensor_info.layout.is_contiguous() { crate::bail!( "cannot retrieve non-contiguous tensors {:?}", tensor_info.layout ) } + let start_offset = tensor_info.layout.start_offset(); + if start_offset > 0 { + std::io::copy( + &mut reader.by_ref().take(start_offset as u64), + &mut std::io::sink(), + )?; + } let tensor = Tensor::from_reader( tensor_info.layout.shape().clone(), tensor_info.dtype,