mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00
Handle start-offset when loading a tensor from a pickle file. (#1546)
This commit is contained in:
@ -703,6 +703,7 @@ impl PthTensors {
|
||||
}
|
||||
|
||||
pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
|
||||
use std::io::Read;
|
||||
let tensor_info = match self.tensor_infos.get(name) {
|
||||
None => return Ok(None),
|
||||
Some(tensor_info) => tensor_info,
|
||||
@ -712,14 +713,21 @@ impl PthTensors {
|
||||
let mut zip = zip::ZipArchive::new(zip_reader)?;
|
||||
let mut reader = zip.by_name(&tensor_info.path)?;
|
||||
|
||||
// Reading the data is a bit tricky as it can be strided, use an offset, etc.
|
||||
// For now only support the basic case.
|
||||
if tensor_info.layout.start_offset() != 0 || !tensor_info.layout.is_contiguous() {
|
||||
// Reading the data is a bit tricky as it can be strided, for now only support the basic
|
||||
// case.
|
||||
if !tensor_info.layout.is_contiguous() {
|
||||
crate::bail!(
|
||||
"cannot retrieve non-contiguous tensors {:?}",
|
||||
tensor_info.layout
|
||||
)
|
||||
}
|
||||
let start_offset = tensor_info.layout.start_offset();
|
||||
if start_offset > 0 {
|
||||
std::io::copy(
|
||||
&mut reader.by_ref().take(start_offset as u64),
|
||||
&mut std::io::sink(),
|
||||
)?;
|
||||
}
|
||||
let tensor = Tensor::from_reader(
|
||||
tensor_info.layout.shape().clone(),
|
||||
tensor_info.dtype,
|
||||
|
Reference in New Issue
Block a user